// Tizen 2.1 base
// [platform/upstream/libbullet.git] / Extras / RigidBodyGpuPipeline / opencl / primitives / AdlPrimitives / Sort / RadixSortStandardKernels.hlsl
/*
                2011 Takahiro Harada
*/
4
// ---------------------------------------------------------------------------
// Compatibility layer: short aliases used by every kernel in this file.
// ---------------------------------------------------------------------------

// 32-bit unsigned word.
typedef uint u32;

// Constructor-style helpers; they expand to the built-in HLSL vector types.
#define make_uint2 uint2
#define make_uint4 uint4

// Argument list every kernel entry point declares.  The GET_*_IDX macros below
// read the x component of these system values, so the kernels only use the
// first dispatch dimension.
#define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID
#define GET_GROUP_IDX groupIdx.x
#define GET_LOCAL_IDX localIdx.x
#define GET_GLOBAL_IDX globalIdx.x

// Full group barrier: group-shared memory barrier plus thread synchronisation.
#define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()

// Expands to nothing, so code "separated" by GROUP_MEM_FENCE has no explicit
// ordering between threads.
// NOTE(review): the scan ladders that use it appear to rely on the
// participating threads executing in lockstep -- confirm for the target GPUs.
#define GROUP_MEM_FENCE

// Atomic increment by one; AtomInc1 also hands back the value InterlockedAdd
// reports through its third argument.
#define AtomInc(x) InterlockedAdd(x, 1)
#define AtomInc1(x, out) InterlockedAdd(x, 1, out)
18
// Component-wise select.  For each lane, a non-zero `condition` picks the lane
// from `a`; a zero `condition` picks it from `b`.  Note the argument order:
// the "false" operand comes first.
uint4 SELECT_UINT4( uint4 b, uint4 a, uint4 condition )
{
    uint4 picked;
    picked.x = condition.x ? a.x : b.x;
    picked.y = condition.y ? a.y : b.y;
    picked.z = condition.z ? a.z : b.z;
    picked.w = condition.w ? a.w : b.w;
    return picked;
}
20
//      takahiro end

// Threads per group; also the numthreads() x dimension of every kernel below.
#define WG_SIZE 128
// Records handled by each thread (work item).
#define NUM_PER_WI 4

// Group size alias.
#define GET_GROUP_SIZE WG_SIZE
26
// One record being sorted: the kernels order records by bits of m_key and
// move m_value along with it.
struct SortData
{
    u32 m_key;
    u32 m_value;
};
32
// Per-dispatch constants.
//   m_startBit  - lowest key bit of the digit handled by this pass.
//   m_numGroups - stride used to index the radix-major histogram.
//   m_padding   - unused; rounds the buffer up to one 16-byte register.
//
// The padding is a uint2 rather than `u32 m_padding[2]`: inside a cbuffer each
// array element is placed in its own 16-byte register, so the array form made
// the shader expect a 36-byte buffer (48 once rounded up) instead of 16.
// The offsets of m_startBit (0) and m_numGroups (4) are the same either way.
cbuffer SortCB : register( b0 )
{
    u32 m_startBit;
    u32 m_numGroups;
    uint2 m_padding;
};

// Key bits consumed per sorting pass; the histograms have (1<<BITS_PER_PASS) bins.
#define BITS_PER_PASS 4
41
42
// Inclusive prefix sum over the four components of `data`:
// returns (x, x+y, x+y+z, x+y+z+w).
uint4 prefixScanVector( uint4 data )
{
    const uint lowPair  = data.x + data.y;
    const uint highPair = data.z + data.w;
    return make_uint4( data.x, lowPair, lowPair + data.z, lowPair + highPair );
}
51
// Exclusive prefix sum over the four components of `data`, in place:
// data becomes (0, x, x+y, x+y+z).  Returns the sum of all four inputs.
uint prefixScanVectorEx( inout uint4 data )
{
    const uint upToX = data.x;
    const uint upToY = upToX + data.y;
    const uint upToZ = upToY + data.z;
    const uint total = upToZ + data.w;

    data = make_uint4( 0, upToX, upToY, upToZ );
    return total;
}
63
64
65
// Resources of LocalSortKernel.
//   sortDataIn       - records; read and overwritten in place, one tile per group.
//   ldsHistogramOut0 - per-group histogram, indexed m_numGroups*bin + group.
//   ldsHistogramOut1 - per-group histogram, indexed numBins*group + bin.
RWStructuredBuffer<SortData> sortDataIn : register( u0 );
RWStructuredBuffer<u32> ldsHistogramOut0 : register( u1 );
RWStructuredBuffer<u32> ldsHistogramOut1 : register( u2 );

// Group-shared scratch, reused for the prefix sum, the key/value shuffle and
// the histogram bins.
// NOTE(review): the purpose of the extra 16 words is not visible in this file.
groupshared u32 ldsSortData[ WG_SIZE*NUM_PER_WI + 16 ];
71
72
// Group-wide exclusive prefix sum over WG_SIZE*4 values (four per thread).
//
//   pData    - this thread's four values.
//   lIdx     - thread index within the group.
//   totalSum - receives the sum over the whole group.
//   returns  - for each of the four values, the sum of every value that
//              precedes it in (thread, component) order.
//
// Uses ldsSortData[0 .. 2*WG_SIZE) as scratch.  Contains two group barriers,
// so every thread of the group must call it.
uint4 localPrefixSum128V( uint4 pData, uint lIdx, inout uint totalSum )
{
    // Scan the four values privately: pData becomes their exclusive scan and
    // threadTotal their sum.
    const uint threadTotal = prefixScanVectorEx( pData );

    // Lower WG_SIZE words are zero so the ladder below can read "before the
    // first element" without a bounds test; the upper WG_SIZE words hold the
    // per-thread totals.
    ldsSortData[lIdx] = 0;
    ldsSortData[lIdx + WG_SIZE] = threadTotal;

    GROUP_LDS_BARRIER;

    // Half of the threads scan the WG_SIZE totals.  Each thread owns one odd
    // slot: it first folds in its even neighbour, then accumulates over the
    // odd slots with doubling distance.  The ladder is written out for
    // WG_SIZE == 128.
    // NOTE(review): GROUP_MEM_FENCE is empty, so the steps are ordered only if
    // these 64 threads run in lockstep -- confirm for the target hardware.
    if( lIdx < 64 )
    {
        const int slot = 2*lIdx + (WG_SIZE+1);

        ldsSortData[slot] += ldsSortData[slot-1];
        GROUP_MEM_FENCE;
        ldsSortData[slot] += ldsSortData[slot-2];
        GROUP_MEM_FENCE;
        ldsSortData[slot] += ldsSortData[slot-4];
        GROUP_MEM_FENCE;
        ldsSortData[slot] += ldsSortData[slot-8];
        GROUP_MEM_FENCE;
        ldsSortData[slot] += ldsSortData[slot-16];
        GROUP_MEM_FENCE;
        ldsSortData[slot] += ldsSortData[slot-32];
        GROUP_MEM_FENCE;
        ldsSortData[slot] += ldsSortData[slot-64];
        GROUP_MEM_FENCE;

        // Complete the even slot from the finished odd slot before it.
        ldsSortData[slot-1] += ldsSortData[slot-2];
        GROUP_MEM_FENCE;
    }

    GROUP_LDS_BARRIER;

    // The last slot holds the inclusive sum over all threads; the slot just
    // before this thread's own holds the sum of all earlier threads.
    totalSum = ldsSortData[WG_SIZE*2-1];
    const uint precedingSum = ldsSortData[lIdx+127];
    return pData + make_uint4( precedingSum, precedingSum, precedingSum, precedingSum );
}
112
// Counts, per group, how often each digit occurs.
//
//   lIdx       - thread index within the group.
//   wgIdx      - group index (not used here).
//   sortedData - this thread's four keys, already shifted so the digit sits
//                in the low BITS_PER_PASS bits.
//
// On return ldsSortData[0 .. (1<<BITS_PER_PASS)) holds this thread's
// increments; the caller must issue a barrier before reading the bins.
// Contains a group barrier, so every thread of the group must call it.
void generateHistogram(u32 lIdx, u32 wgIdx,
                uint4 sortedData)
{
    const uint digitMask = (1<<BITS_PER_PASS) - 1;
    const uint4 bins = make_uint4( sortedData.x & digitMask, sortedData.y & digitMask,
                                   sortedData.z & digitMask, sortedData.w & digitMask );

    // The first (1<<BITS_PER_PASS) threads clear one bin each.
    if( lIdx < (1<<BITS_PER_PASS) )
    {
        ldsSortData[lIdx] = 0;
    }

    // Bins must be cleared before anyone starts counting.
    GROUP_LDS_BARRIER;

    AtomInc( ldsSortData[bins.x] );
    AtomInc( ldsSortData[bins.y] );
    AtomInc( ldsSortData[bins.z] );
    AtomInc( ldsSortData[bins.w] );
}
131
// Sorts one tile of WG_SIZE*NUM_PER_WI records per thread group on the
// BITS_PER_PASS key bits starting at m_startBit, then
//   * writes the group's digit histogram to ldsHistogramOut0 (radix-major)
//     and ldsHistogramOut1 (group-major), and
//   * writes the sorted tile back over its input in sortDataIn.
//
// The sort is one split per bit: records whose bit is clear keep their
// relative order at the front of the tile, records whose bit is set follow.
// No bounds test is made on sortDataIn.
// NOTE(review): presumably the caller pads the input to a whole number of
// tiles -- confirm on the host side.
[numthreads(WG_SIZE, 1, 1)]
void LocalSortKernel( DEFAULT_ARGS )
{
    const u32 lIdx = GET_LOCAL_IDX;
    const u32 wgIdx = GET_GROUP_IDX;

    // First record of this group's tile, and the four consecutive tile slots
    // owned by this thread.
    const u32 tileBase = (WG_SIZE*NUM_PER_WI)*wgIdx;
    const uint4 localAddr = make_uint4(lIdx*4+0, lIdx*4+1, lIdx*4+2, lIdx*4+3);

    SortData sortData[NUM_PER_WI];
    sortData[0] = sortDataIn[tileBase+localAddr.x];
    sortData[1] = sortDataIn[tileBase+localAddr.y];
    sortData[2] = sortDataIn[tileBase+localAddr.z];
    sortData[3] = sortDataIn[tileBase+localAddr.w];

    u32 bitIdx = m_startBit;
    do
    {
        const u32 mask = (1<<bitIdx);
        const uint4 cmpResult = make_uint4( sortData[0].m_key & mask, sortData[1].m_key & mask,
                                            sortData[2].m_key & mask, sortData[3].m_key & mask );
        const uint4 bitSet = ( cmpResult != make_uint4(0,0,0,0) );

        // Flag records whose bit is clear, then turn the flags into their
        // exclusive prefix sum: the rank of each record among the "clear" ones.
        uint4 zeroRank = SELECT_UINT4( make_uint4(1,1,1,1), make_uint4(0,0,0,0), bitSet );
        u32 numZeros = 0;       // initialised: it is handed to an inout parameter
        zeroRank = localPrefixSum128V( zeroRank, lIdx, numZeros );

        // Bit clear: destination is the rank among the zeros.
        // Bit set:   destination is (slot - zeros before it) + total zeros.
        uint4 dstAddr = localAddr - zeroRank + make_uint4( numZeros, numZeros, numZeros, numZeros );
        dstAddr = SELECT_UINT4( zeroRank, dstAddr, bitSet );

        // The scan scratch in ldsSortData is about to be overwritten.
        GROUP_LDS_BARRIER;

        // Shuffle the keys through group-shared memory ...
        ldsSortData[dstAddr.x] = sortData[0].m_key;
        ldsSortData[dstAddr.y] = sortData[1].m_key;
        ldsSortData[dstAddr.z] = sortData[2].m_key;
        ldsSortData[dstAddr.w] = sortData[3].m_key;

        GROUP_LDS_BARRIER;

        sortData[0].m_key = ldsSortData[localAddr.x];
        sortData[1].m_key = ldsSortData[localAddr.y];
        sortData[2].m_key = ldsSortData[localAddr.z];
        sortData[3].m_key = ldsSortData[localAddr.w];

        GROUP_LDS_BARRIER;

        // ... then the values, through the same scratch and the same addresses.
        ldsSortData[dstAddr.x] = sortData[0].m_value;
        ldsSortData[dstAddr.y] = sortData[1].m_value;
        ldsSortData[dstAddr.z] = sortData[2].m_value;
        ldsSortData[dstAddr.w] = sortData[3].m_value;

        GROUP_LDS_BARRIER;

        sortData[0].m_value = ldsSortData[localAddr.x];
        sortData[1].m_value = ldsSortData[localAddr.y];
        sortData[2].m_value = ldsSortData[localAddr.z];
        sortData[3].m_value = ldsSortData[localAddr.w];

        GROUP_LDS_BARRIER;

        bitIdx ++;
    }
    while( bitIdx < (m_startBit+BITS_PER_PASS) );

    {       //      generate histogram
        const uint4 localKeys = make_uint4( sortData[0].m_key>>m_startBit, sortData[1].m_key>>m_startBit,
                                            sortData[2].m_key>>m_startBit, sortData[3].m_key>>m_startBit );

        generateHistogram( lIdx, wgIdx, localKeys );

        // All increments must have landed before the bins are read.
        GROUP_LDS_BARRIER;

        const u32 nBins = (1<<BITS_PER_PASS);
        if( lIdx < nBins )
        {
            const u32 histValue = ldsSortData[lIdx];

            // Same count, two layouts: bin-major across groups, and
            // group-major across bins.
            ldsHistogramOut0[m_numGroups*lIdx + wgIdx] = histValue;
            ldsHistogramOut1[nBins*wgIdx + lIdx] = histValue;
        }
    }

    {       //      write the sorted tile back in place
        sortDataIn[ tileBase+localAddr.x ] = sortData[0];
        sortDataIn[ tileBase+localAddr.y ] = sortData[1];
        sortDataIn[ tileBase+localAddr.z ] = sortData[2];
        sortDataIn[ tileBase+localAddr.w ] = sortData[3];
    }
}
233
// Resources of ScatterKernel and CopyKernel.
//   src                       - records to read.
//   histogramGlobalRadixMajor - per (bin, group) value, indexed
//                               m_numGroups*bin + group; ScatterKernel uses it
//                               as the base destination of that bin's records.
//   histogramLocalGroupMajor  - per (group, bin) count, indexed numBins*group + bin.
//   dst                       - records to write.
// NOTE(review): dst uses register u0, the same slot as sortDataIn above;
// presumably each kernel is bound separately -- confirm on the host side.
StructuredBuffer<SortData> src : register( t0 );
StructuredBuffer<u32> histogramGlobalRadixMajor : register( t1 );
StructuredBuffer<u32> histogramLocalGroupMajor : register( t2 );

RWStructuredBuffer<SortData> dst : register( u0 );

// Scratch for ScatterKernel: the local histogram is scanned in the upper half
// of ldsLocalHistogram and the result is stored in the lower half.
groupshared u32 ldsLocalHistogram[ 2*(1<<BITS_PER_PASS) ];
groupshared u32 ldsGlobalHistogram[ (1<<BITS_PER_PASS) ];
242
243
// Moves each record of a group's (already locally sorted) tile from src to
// its destination in dst for the current digit:
//
//   dst index = histogramGlobalRadixMajor[bin, group]
//             + (tile slot - number of tile records in smaller bins)
//
// The second term is the record's rank inside its own bin, which is valid
// because the tile is sorted by digit.
[numthreads(WG_SIZE, 1, 1)]
void ScatterKernel( DEFAULT_ARGS )
{
    const u32 lIdx = GET_LOCAL_IDX;
    const u32 wgIdx = GET_GROUP_IDX;
    const u32 numBins = (1<<BITS_PER_PASS);
    const u32 ldsOffset = numBins;      // the scan works in the upper half of ldsLocalHistogram

    //      load and prefix scan local histogram
    // Half as many threads as bins; each handles bin lIdx and bin lIdx + numBins/2.
    // NOTE(review): GROUP_MEM_FENCE is empty, so the steps below are ordered
    // only if these threads run in lockstep -- confirm for the target hardware.
    if( lIdx < numBins/2 )
    {
        const uint2 myIdx = make_uint2(lIdx, lIdx + numBins/2);

        // Counts go to the upper half; the lower half is zeroed so the ladder
        // can read below the first bin.
        ldsLocalHistogram[ldsOffset+myIdx.x] = histogramLocalGroupMajor[numBins*wgIdx + myIdx.x];
        ldsLocalHistogram[ldsOffset+myIdx.y] = histogramLocalGroupMajor[numBins*wgIdx + myIdx.y];
        ldsLocalHistogram[myIdx.x] = 0;
        ldsLocalHistogram[myIdx.y] = 0;

        // Each thread owns one even slot of the upper half.  The ladder is
        // written out for 16 bins (BITS_PER_PASS == 4).
        const u32 idx = ldsOffset+2*lIdx;
        ldsLocalHistogram[idx] += ldsLocalHistogram[idx-1];
        GROUP_MEM_FENCE;
        ldsLocalHistogram[idx] += ldsLocalHistogram[idx-2];
        GROUP_MEM_FENCE;
        ldsLocalHistogram[idx] += ldsLocalHistogram[idx-4];
        GROUP_MEM_FENCE;
        ldsLocalHistogram[idx] += ldsLocalHistogram[idx-8];
        GROUP_MEM_FENCE;

        // Propagate intermediate values through
        ldsLocalHistogram[idx-1] += ldsLocalHistogram[idx-2];
        GROUP_MEM_FENCE;

        // Reading one slot below each bin turns the inclusive sums into the
        // exclusive prefix: the number of tile records in smaller bins.
        uint2 localValues;
        localValues.x = ldsLocalHistogram[ldsOffset+myIdx.x-1];
        localValues.y = ldsLocalHistogram[ldsOffset+myIdx.y-1];

        ldsLocalHistogram[myIdx.x] = localValues.x;
        ldsLocalHistogram[myIdx.y] = localValues.y;

        // Base destination of this group's records for each bin.
        ldsGlobalHistogram[myIdx.x] = histogramGlobalRadixMajor[m_numGroups*myIdx.x + wgIdx];
        ldsGlobalHistogram[myIdx.y] = histogramGlobalRadixMajor[m_numGroups*myIdx.y + wgIdx];
    }

    GROUP_LDS_BARRIER;

    // Four consecutive tile slots per thread.
    const uint4 localAddr = make_uint4(lIdx*4+0, lIdx*4+1, lIdx*4+2, lIdx*4+3);

    SortData sortData[4];
    {
        const uint4 globalAddr = wgIdx*WG_SIZE*NUM_PER_WI + localAddr;
        sortData[0] = src[globalAddr.x];
        sortData[1] = src[globalAddr.y];
        sortData[2] = src[globalAddr.z];
        sortData[3] = src[globalAddr.w];
    }

    const uint cmpValue = numBins - 1;
    const uint4 radix = make_uint4( (sortData[0].m_key>>m_startBit)&cmpValue, (sortData[1].m_key>>m_startBit)&cmpValue,
                                    (sortData[2].m_key>>m_startBit)&cmpValue, (sortData[3].m_key>>m_startBit)&cmpValue );

    //      data is already sorted. So simply subtract local prefix sum
    uint4 dstAddr;
    dstAddr.x = ldsGlobalHistogram[radix.x] + (localAddr.x - ldsLocalHistogram[radix.x]);
    dstAddr.y = ldsGlobalHistogram[radix.y] + (localAddr.y - ldsLocalHistogram[radix.y]);
    dstAddr.z = ldsGlobalHistogram[radix.z] + (localAddr.z - ldsLocalHistogram[radix.z]);
    dstAddr.w = ldsGlobalHistogram[radix.w] + (localAddr.w - ldsLocalHistogram[radix.w]);

    dst[dstAddr.x] = sortData[0];
    dst[dstAddr.y] = sortData[1];
    dst[dstAddr.z] = sortData[2];
    dst[dstAddr.w] = sortData[3];
}
317
// Copies one record per thread from src to dst at the thread's global index.
[numthreads(WG_SIZE, 1, 1)]
void CopyKernel( DEFAULT_ARGS )
{
    const uint recordIdx = GET_GLOBAL_IDX;
    dst[ recordIdx ] = src[ recordIdx ];
}