// Portability layer: the kernels below use generic names (GET_LOCAL_IDX,
// GROUP_LDS_BARRIER, AtomInc, make_uint4, SELECT_UINT4, ...) that are mapped
// here onto HLSL compute-shader built-ins.
// Group / in-group / global thread indices, taken from the system values
// declared by DEFAULT_ARGS.
7 #define GET_GROUP_IDX groupIdx.x
8 #define GET_LOCAL_IDX localIdx.x
9 #define GET_GLOBAL_IDX globalIdx.x
// Group-wide barrier: waits until all groupshared accesses are complete and
// every thread of the group has reached this point.
10 #define GROUP_LDS_BARRIER GroupMemoryBarrierWithGroupSync()
// Expands to nothing in this HLSL build.
// NOTE(review): code that uses GROUP_MEM_FENCE to order groupshared accesses
// between threads gets no fence here — confirm that is intended.
11 #define GROUP_MEM_FENCE
// Entry-point parameter list shared by every [numthreads] kernel below.
12 #define DEFAULT_ARGS uint3 globalIdx : SV_DispatchThreadID, uint3 localIdx : SV_GroupThreadID, uint3 groupIdx : SV_GroupID
// Atomic increment by 1; AtomInc1 also returns the pre-increment value in `out`.
13 #define AtomInc(x) InterlockedAdd(x, 1)
14 #define AtomInc1(x, out) InterlockedAdd(x, 1, out)
// CUDA-style constructor names mapped to HLSL vector constructors.
16 #define make_uint4 uint4
17 #define make_uint2 uint2
// Component-wise select: for each lane, returns a where condition is non-zero,
// otherwise b. Note the argument order is (b, a, condition).
19 uint4 SELECT_UINT4(uint4 b,uint4 a,uint4 condition ){ return make_uint4( ((condition).x)?a.x:b.x, ((condition).y)?a.y:b.y, ((condition).z)?a.z:b.z, ((condition).w)?a.w:b.w ); }
// Threads per work-group. WG_SIZE itself is defined outside this listing.
25 #define GET_GROUP_SIZE WG_SIZE
// Per-dispatch sort parameters. The member list is not visible in this listing.
// The kernels below read m_startBit and m_numGroups, which are presumably
// declared here — confirm.
33 cbuffer SortCB : register( b0 )
// Number of key bits consumed per radix pass. 1<<BITS_PER_PASS (= 16) is the
// number of histogram bins.
40 #define BITS_PER_PASS 4
// In-thread scans over the four lanes of a uint4. Their bodies are not visible
// in this listing.
// NOTE(review): localPrefixSum128V stores the return value of
// prefixScanVectorEx as the per-thread total, and later adds an offset to the
// (inout) data. prefixScanVectorEx therefore presumably rewrites `data` as its
// exclusive in-thread scan and returns the sum of the four lanes. The exact
// contract of prefixScanVector is not evident from this listing — confirm
// against the definitions.
43 uint4 prefixScanVector( uint4 data )
52 uint prefixScanVectorEx( inout uint4 data )
// LocalSortKernel resources.
// sortDataIn: key/value pairs. Each work-group loads its tile, sorts it by the
//   current digit and writes it back in place.
66 RWStructuredBuffer<SortData> sortDataIn : register( u0 );
// Per-group digit counts, radix-major: [bin * m_numGroups + group].
// (Despite the "lds" prefix, these are RWStructuredBuffer UAVs, not groupshared memory.)
67 RWStructuredBuffer<u32> ldsHistogramOut0 : register( u1 );
// The same counts, group-major: [group * (1<<BITS_PER_PASS) + bin].
68 RWStructuredBuffer<u32> ldsHistogramOut1 : register( u2 );
// Groupshared scratch, reused for the group-wide prefix sum, the key/value
// reorder and the 16-bin histogram.
// NOTE(review): the purpose of the extra 16 words is not evident from the
// visible uses — confirm.
70 groupshared u32 ldsSortData[ WG_SIZE*NUM_PER_WI + 16 ];
// Work-group-wide prefix sum over four values per thread.
//   pData    - this thread's four input values.
//   lIdx     - local thread index.
//   totalSum - out: sum of all values in the work-group.
// Returns pData with this thread's exclusive offset (sum of the per-thread
// totals of threads 0..lIdx-1) added to every lane. Combined with
// prefixScanVectorEx, this is presumably the exclusive scan across the whole
// group (thread-major, then lane order) — confirm prefixScanVectorEx's contract.
// ldsSortData layout during the scan:
//   [0, WG_SIZE)          zeros, so the idx-k / lIdx-1 reads never leave the buffer;
//   [WG_SIZE, 2*WG_SIZE)  per-thread totals, scanned in place into inclusive sums.
// NOTE(review): the literals 64 and 127 hard-code WG_SIZE == 128
// (WG_SIZE/2 and WG_SIZE-1).
// NOTE(review): several pieces are not visible in this listing:
//   - idx = 2*lIdx + WG_SIZE+1 stays below 2*WG_SIZE only for lIdx < WG_SIZE/2,
//     so the scan steps must be guarded to those threads;
//   - group barriers are needed between writing the totals and scanning, and
//     before reading the results;
//   - each in-place step reads slots other threads write in the same step, so
//     it relies on lockstep execution or a barrier per step.
// Confirm all three.
73 uint4 localPrefixSum128V( uint4 pData, uint lIdx, inout uint totalSum )
76 ldsSortData[lIdx] = 0;
77 ldsSortData[lIdx+WG_SIZE] = prefixScanVectorEx( pData );
// Each scanning thread owns the pair (idx-1, idx) of the totals region.
// Step 1 forms the pair sum at idx. The following steps double the stride in
// pairs (stride-doubling scan), leaving idx = inclusive sum up to the pair end.
83 int idx = 2*lIdx + (WG_SIZE+1);
86 ldsSortData[idx] += ldsSortData[idx-1];
88 ldsSortData[idx] += ldsSortData[idx-2];
90 ldsSortData[idx] += ldsSortData[idx-4];
92 ldsSortData[idx] += ldsSortData[idx-8];
94 ldsSortData[idx] += ldsSortData[idx-16];
96 ldsSortData[idx] += ldsSortData[idx-32];
98 ldsSortData[idx] += ldsSortData[idx-64];
// Complete the first element of each pair from the inclusive sum just left of it.
// For the first pair, that slot is zero padding.
101 ldsSortData[idx-1] += ldsSortData[idx-2];
// The last slot holds the inclusive total of the whole group.
108 totalSum = ldsSortData[WG_SIZE*2-1];
// Slot lIdx+127 (== WG_SIZE+lIdx-1 when WG_SIZE == 128) is the inclusive sum
// of threads 0..lIdx-1, i.e. this thread's exclusive offset. For lIdx == 0 it
// reads the zero padding.
109 uint addValue = ldsSortData[lIdx+127];
110 return pData + make_uint4(addValue, addValue, addValue, addValue);
// Accumulates the 1<<BITS_PER_PASS (= 16) bin histogram of this thread's four
// digits into ldsSortData[0, 16).
//   lIdx       - local thread index. Threads lIdx < 16 clear one bin each.
//   wgIdx      - group index (not used in the visible lines).
//   sortedData - declared in the part of the signature not visible here.
//                The caller passes a uint4 of keys already shifted right by
//                m_startBit. Only their low BITS_PER_PASS bits select the bin.
// NOTE(review): two group barriers are required and are not visible in this
// listing — one between clearing and the atomic increments, one before the
// caller reads the bins. Confirm both exist.
113 void generateHistogram(u32 lIdx, u32 wgIdx,
116 if( lIdx < (1<<BITS_PER_PASS) )
118 ldsSortData[lIdx] = 0;
// Reduce each value to its BITS_PER_PASS-bit digit (bin index).
121 int mask = ((1<<BITS_PER_PASS)-1);
122 uint4 keys = make_uint4( (sortedData.x)&mask, (sortedData.y)&mask, (sortedData.z)&mask, (sortedData.w)&mask );
// Several threads may hit the same bin, hence atomic increments.
126 AtomInc( ldsSortData[keys.x] );
127 AtomInc( ldsSortData[keys.y] );
128 AtomInc( ldsSortData[keys.z] );
129 AtomInc( ldsSortData[keys.w] );
// Radix pass, step 1: each work-group sorts its tile of WG_SIZE*NUM_PER_WI
// key/value pairs in place. The sort key is the BITS_PER_PASS-bit digit
// starting at m_startBit, processed as one stable 1-bit split per bit.
// The kernel then publishes the tile's digit histogram.
// Outputs:
//   sortDataIn       - the tile, stably reordered by digit.
//   ldsHistogramOut0 - digit counts, radix-major [bin * m_numGroups + group].
//   ldsHistogramOut1 - digit counts, group-major [group * nBins + bin].
// NOTE(review): the code handles exactly four elements per thread
// (lIdx*4+0..3), so NUM_PER_WI must be 4.
// NOTE(review): the following are not visible in this listing — confirm they exist:
//   - the `do {` opener and the bitIdx increment;
//   - `total`'s declaration;
//   - group barriers around each LDS write/read exchange;
//   - the lIdx < nBins guard on the histogram store.
132 [numthreads(WG_SIZE, 1, 1)]
133 void LocalSortKernel( DEFAULT_ARGS )
135 int nElemsPerWG = WG_SIZE*NUM_PER_WI;
136 u32 lIdx = GET_LOCAL_IDX;
137 u32 wgIdx = GET_GROUP_IDX;
138 u32 wgSize = GET_GROUP_SIZE;
// This thread's four consecutive slots within the tile.
140 uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);
143 SortData sortData[NUM_PER_WI];
// Load this group's tile.
146 u32 offset = nElemsPerWG*wgIdx;
147 sortData[0] = sortDataIn[offset+localAddr.x];
148 sortData[1] = sortDataIn[offset+localAddr.y];
149 sortData[2] = sortDataIn[offset+localAddr.z];
150 sortData[3] = sortDataIn[offset+localAddr.w];
// One split per bit, for bits m_startBit .. m_startBit+BITS_PER_PASS-1.
153 int bitIdx = m_startBit;
157 // if( lIdx == wgSize-1 ) ldsSortData[256] = sortData[3].m_key;
// Non-zero where the current bit of the key is set.
158 u32 mask = (1<<bitIdx);
159 uint4 cmpResult = make_uint4( sortData[0].m_key & mask, sortData[1].m_key & mask, sortData[2].m_key & mask, sortData[3].m_key & mask );
// Flag = 1 for elements whose bit is 0. The group-wide prefix sum of the flags
// gives each 0-element its destination; `total` receives the number of 0-elements.
160 uint4 prefixSum = SELECT_UINT4( make_uint4(1,1,1,1), make_uint4(0,0,0,0), cmpResult != make_uint4(0,0,0,0) );
162 prefixSum = localPrefixSum128V( prefixSum, lIdx, total );
// A 1-element goes after all 0-elements, at total + (number of 1-elements
// before it). Both groups keep their relative order, so the split is stable.
165 uint4 dstAddr = localAddr - prefixSum + make_uint4( total, total, total, total );
166 dstAddr = SELECT_UINT4( prefixSum, dstAddr, cmpResult != make_uint4(0, 0, 0, 0) );
// Permute keys through groupshared memory...
170 ldsSortData[dstAddr.x] = sortData[0].m_key;
171 ldsSortData[dstAddr.y] = sortData[1].m_key;
172 ldsSortData[dstAddr.z] = sortData[2].m_key;
173 ldsSortData[dstAddr.w] = sortData[3].m_key;
177 sortData[0].m_key = ldsSortData[localAddr.x];
178 sortData[1].m_key = ldsSortData[localAddr.y];
179 sortData[2].m_key = ldsSortData[localAddr.z];
180 sortData[3].m_key = ldsSortData[localAddr.w];
// ...then values with the same addresses, reusing the same scratch.
184 ldsSortData[dstAddr.x] = sortData[0].m_value;
185 ldsSortData[dstAddr.y] = sortData[1].m_value;
186 ldsSortData[dstAddr.z] = sortData[2].m_value;
187 ldsSortData[dstAddr.w] = sortData[3].m_value;
191 sortData[0].m_value = ldsSortData[localAddr.x];
192 sortData[1].m_value = ldsSortData[localAddr.y];
193 sortData[2].m_value = ldsSortData[localAddr.z];
194 sortData[3].m_value = ldsSortData[localAddr.w];
200 while( bitIdx <(m_startBit+BITS_PER_PASS) );
202 { // generate histogram
203 uint4 localKeys = make_uint4( sortData[0].m_key>>m_startBit, sortData[1].m_key>>m_startBit,
204 sortData[2].m_key>>m_startBit, sortData[3].m_key>>m_startBit );
206 generateHistogram( lIdx, wgIdx, localKeys );
// Thread lIdx publishes the count of bin lIdx in both layouts.
210 int nBins = (1<<BITS_PER_PASS);
213 u32 histValues = ldsSortData[lIdx];
215 u32 globalAddresses = nBins*wgIdx + lIdx;
216 u32 globalAddressesRadixMajor = m_numGroups*lIdx + wgIdx;
218 ldsHistogramOut0[globalAddressesRadixMajor] = histValues;
219 ldsHistogramOut1[globalAddresses] = histValues;
// Write the sorted tile back to the same place it was loaded from.
// dstAddr.x + k equals offset + localAddr[k] because the four slots are consecutive.
224 u32 offset = nElemsPerWG*wgIdx;
225 uint4 dstAddr = make_uint4(offset+localAddr.x, offset+localAddr.y, offset+localAddr.z, offset+localAddr.w );
227 sortDataIn[ dstAddr.x + 0 ] = sortData[0];
228 sortDataIn[ dstAddr.x + 1 ] = sortData[1];
229 sortDataIn[ dstAddr.x + 2 ] = sortData[2];
230 sortDataIn[ dstAddr.x + 3 ] = sortData[3];
// ScatterKernel / CopyKernel resources.
// src: key/value pairs to scatter (or copy).
234 StructuredBuffer<SortData> src : register( t0 );
// Indexed [bin * m_numGroups + group]. ScatterKernel uses each entry as this
// group's global start offset for the bin. It is therefore presumably the
// exclusive scan of LocalSortKernel's radix-major counts, produced by a pass
// outside this listing — confirm.
235 StructuredBuffer<u32> histogramGlobalRadixMajor : register( t1 );
// Per-group digit counts, indexed [group * 16 + bin]. This is the same layout
// LocalSortKernel writes to ldsHistogramOut1.
236 StructuredBuffer<u32> histogramLocalGroupMajor : register( t2 );
// Scatter / copy destination.
238 RWStructuredBuffer<SortData> dst : register( u0 );
// [0,16): zero padding for the scan, later overwritten with each bin's start
// inside the tile. [16,32): this group's counts, scanned in place.
240 groupshared u32 ldsLocalHistogram[ 2*(1<<BITS_PER_PASS) ];
// This group's global start offset per bin.
241 groupshared u32 ldsGlobalHistogram[ (1<<BITS_PER_PASS) ];
// Radix pass, step 2: scatter each group's locally sorted tile to its final
// position. For an element with digit d at tile position p:
//   dst index = ldsGlobalHistogram[d] + (p - ldsLocalHistogram[d])
// where ldsLocalHistogram[d] is where digit d starts inside the tile
// (exclusive scan of this group's counts), so p - ldsLocalHistogram[d] is
// the element's rank within its bin.
// NOTE(review): the literal 8 below equals (1<<BITS_PER_PASS)/2 and must change
// together with BITS_PER_PASS.
// NOTE(review): the following are not visible in this listing — confirm they exist:
//   - the localValues / dstAddr declarations;
//   - the braces;
//   - a group barrier between the histogram setup and its use by all threads.
// NOTE(review): thread 7 overwrites slot 15 (its myIdx.y store) while thread 0
// reads slot 15 expecting the zero padding. This, like the in-place scan steps,
// is correct only with lockstep execution or barriers — confirm.
244 [numthreads(WG_SIZE, 1, 1)]
245 void ScatterKernel( DEFAULT_ARGS )
247 u32 lIdx = GET_LOCAL_IDX;
248 u32 wgIdx = GET_GROUP_IDX;
249 u32 ldsOffset = (1<<BITS_PER_PASS);
251 // load and prefix scan local histogram
// Eight threads each own bins lIdx and lIdx+8.
252 if( lIdx < ((1<<BITS_PER_PASS)/2) )
254 uint2 myIdx = make_uint2(lIdx, lIdx+8);
// Counts go to the upper half; the lower half is zeroed as scan padding.
256 ldsLocalHistogram[ldsOffset+myIdx.x] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.x];
257 ldsLocalHistogram[ldsOffset+myIdx.y] = histogramLocalGroupMajor[(1<<BITS_PER_PASS)*wgIdx + myIdx.y];
258 ldsLocalHistogram[ldsOffset+myIdx.x-(1<<BITS_PER_PASS)] = 0;
259 ldsLocalHistogram[ldsOffset+myIdx.y-(1<<BITS_PER_PASS)] = 0;
// Even slots 16+2*lIdx first take the pair sum with the slot before them, then
// a stride-doubling scan. Afterwards slot 16+2i holds the inclusive count of
// bins 0..2i.
261 int idx = ldsOffset+2*lIdx;
262 ldsLocalHistogram[idx] += ldsLocalHistogram[idx-1];
264 ldsLocalHistogram[idx] += ldsLocalHistogram[idx-2];
266 ldsLocalHistogram[idx] += ldsLocalHistogram[idx-4];
268 ldsLocalHistogram[idx] += ldsLocalHistogram[idx-8];
271 // Propagate intermediate values through
// Odd slots 15+2i become the inclusive count of bins 0..2i-1 (slot 15 stays 0).
// Slot 31 keeps the raw count of bin 15, which is never needed.
272 ldsLocalHistogram[idx-1] += ldsLocalHistogram[idx-2];
275 // Exclusive offset of each bin = inclusive scan of the previous bin (index - 1); publish it for the whole WG
277 localValues.x = ldsLocalHistogram[ldsOffset+myIdx.x-1];
278 localValues.y = ldsLocalHistogram[ldsOffset+myIdx.y-1];
280 ldsLocalHistogram[myIdx.x] = localValues.x;
281 ldsLocalHistogram[myIdx.y] = localValues.y;
// This group's global start offset for each of the two bins.
284 ldsGlobalHistogram[myIdx.x] = histogramGlobalRadixMajor[m_numGroups*myIdx.x + wgIdx];
285 ldsGlobalHistogram[myIdx.y] = histogramGlobalRadixMajor[m_numGroups*myIdx.y + wgIdx];
// Every thread loads its four consecutive elements of the tile (assumes NUM_PER_WI == 4).
290 uint4 localAddr = make_uint4(lIdx*4+0,lIdx*4+1,lIdx*4+2,lIdx*4+3);
292 SortData sortData[4];
294 uint4 globalAddr = wgIdx*WG_SIZE*NUM_PER_WI + localAddr;
295 sortData[0] = src[globalAddr.x];
296 sortData[1] = src[globalAddr.y];
297 sortData[2] = src[globalAddr.z];
298 sortData[3] = src[globalAddr.w];
// Digit (bin) of each key for this pass.
301 uint cmpValue = ((1<<BITS_PER_PASS)-1);
302 uint4 radix = make_uint4( (sortData[0].m_key>>m_startBit)&cmpValue, (sortData[1].m_key>>m_startBit)&cmpValue,
303 (sortData[2].m_key>>m_startBit)&cmpValue, (sortData[3].m_key>>m_startBit)&cmpValue );;
305 // data is already sorted. So simply subtract local prefix sum
307 dstAddr.x = ldsGlobalHistogram[radix.x] + (localAddr.x - ldsLocalHistogram[radix.x]);
308 dstAddr.y = ldsGlobalHistogram[radix.y] + (localAddr.y - ldsLocalHistogram[radix.y]);
309 dstAddr.z = ldsGlobalHistogram[radix.z] + (localAddr.z - ldsLocalHistogram[radix.z]);
310 dstAddr.w = ldsGlobalHistogram[radix.w] + (localAddr.w - ldsLocalHistogram[radix.w]);
312 dst[dstAddr.x] = sortData[0];
313 dst[dstAddr.y] = sortData[1];
314 dst[dstAddr.z] = sortData[2];
315 dst[dstAddr.w] = sortData[3];
// Element-wise copy src -> dst, one element per thread, indexed by the global
// thread index.
// NOTE(review): there is no bounds check. The dispatch presumably covers
// exactly the element count (a multiple of WG_SIZE) — confirm.
318 [numthreads(WG_SIZE, 1, 1)]
319 void CopyKernel( DEFAULT_ARGS )
321 dst[ GET_GLOBAL_IDX ] = src[ GET_GLOBAL_IDX ];