/*
2 * Copyright 1993-2006 NVIDIA Corporation. All rights reserved.
6 * This source code is subject to NVIDIA ownership rights under U.S. and
7 * international Copyright laws.
9 * NVIDIA MAKES NO REPRESENTATION ABOUT THE SUITABILITY OF THIS SOURCE
10 * CODE FOR ANY PURPOSE. IT IS PROVIDED "AS IS" WITHOUT EXPRESS OR
11 * IMPLIED WARRANTY OF ANY KIND. NVIDIA DISCLAIMS ALL WARRANTIES WITH
12 * REGARD TO THIS SOURCE CODE, INCLUDING ALL IMPLIED WARRANTIES OF
13 * MERCHANTABILITY, NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
14 * IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL,
15 * OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
16 * OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
17 * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE
18 * OR PERFORMANCE OF THIS SOURCE CODE.
20 * U.S. Government End Users. This source code is a "commercial item" as
21 * that term is defined at 48 C.F.R. 2.101 (OCT 1995), consisting of
22 * "commercial computer software" and "commercial computer software
23 * documentation" as such terms are used in 48 C.F.R. 12.212 (SEPT 1995)
24 * and is provided to the U.S. Government only as a commercial end item.
25 * Consistent with 48 C.F.R.12.212 and 48 C.F.R. 227.7202-1 through
26 * 227.7202-4 (JUNE 1995), all U.S. Government End Users acquire the
27 * source code with only those rights set forth herein.
 */
30 /* Radixsort project with key/value and arbitrary dataset size support
31 * which demonstrates the use of CUDA in a multi phase sorting
 * algorithm.
 */
36 #ifndef _RADIXSORT_KERNEL_H_
37 #define _RADIXSORT_KERNEL_H_
40 #include "radixsort.cuh"
#define SYNCIT __syncthreads()

// Launch configuration (hard-coded for a 16-SM, G80-class device).
static const int NUM_SMS = 16;
static const int NUM_THREADS_PER_SM = 192;
static const int NUM_THREADS_PER_BLOCK = 64;
//static const int NUM_THREADS = NUM_THREADS_PER_SM * NUM_SMS;
static const int NUM_BLOCKS = (NUM_THREADS_PER_SM / NUM_THREADS_PER_BLOCK) * NUM_SMS;

static const int RADIX = 8;                                     // Number of bits per radix sort pass
static const int RADICES = 1 << RADIX;                          // Number of radices
static const int RADIXMASK = RADICES - 1;                       // Mask for each radix sort pass

// FIX(review): the listing defined RADIXBITS twice (16, then 32), which is a
// redefinition error; the two values were evidently alternatives selected by
// conditional compilation. Restore a guard so exactly one definition is
// compiled; full 32-bit key sorting remains the default.
#ifdef SIXTEEN
static const int RADIXBITS = 16;                                // Number of bits to sort over
#else
static const int RADIXBITS = 32;                                // Number of bits to sort over
#endif

static const int RADIXTHREADS = 16;                             // Number of threads sharing each radix counter
static const int RADIXGROUPS = NUM_THREADS_PER_BLOCK / RADIXTHREADS; // Number of radix groups per CTA
static const int TOTALRADIXGROUPS = NUM_BLOCKS * RADIXGROUPS;   // Number of radix groups for each radix
static const int SORTRADIXGROUPS = TOTALRADIXGROUPS * RADICES;  // Total radix count
static const int GRFELEMENTS = (NUM_THREADS_PER_BLOCK / RADIXTHREADS) * RADICES; // Shared counters per block
static const int GRFSIZE = GRFELEMENTS * sizeof(uint);          // ...and their size in bytes

// Prefix sum variables
static const int PREFIX_NUM_THREADS_PER_SM = NUM_THREADS_PER_SM;
static const int PREFIX_NUM_THREADS_PER_BLOCK = PREFIX_NUM_THREADS_PER_SM;
static const int PREFIX_NUM_BLOCKS = (PREFIX_NUM_THREADS_PER_SM / PREFIX_NUM_THREADS_PER_BLOCK) * NUM_SMS;
static const int PREFIX_BLOCKSIZE = SORTRADIXGROUPS / PREFIX_NUM_BLOCKS;
// Work area plus spacer slots between each thread's stretch (see the
// "eliminate GRF conflicts" discussion inside RadixPrefixSum).
static const int PREFIX_GRFELEMENTS = PREFIX_BLOCKSIZE + 2 * PREFIX_NUM_THREADS_PER_BLOCK;
static const int PREFIX_GRFSIZE = PREFIX_GRFELEMENTS * sizeof(uint);

// Shuffle variables
static const int SHUFFLE_GRFOFFSET = RADIXGROUPS * RADICES;     // Start of the block-offset table in shared mem
static const int SHUFFLE_GRFELEMENTS = SHUFFLE_GRFOFFSET + PREFIX_NUM_BLOCKS;
static const int SHUFFLE_GRFSIZE = SHUFFLE_GRFELEMENTS * sizeof(uint);

// Shared-memory access wrapper (bank-conflict checking in emulation builds).
#define SDATA( index) CUT_BANK_CHECKER(sdata, index)

// Per-(radix group, radix) counters and per-block totals. The g* arrays are
// presumably host-side mirrors of the __device__ d* arrays -- TODO confirm
// against the host-side code in radixsort.cuh.
uint gRadixSum[TOTALRADIXGROUPS * RADICES];
__device__ uint dRadixSum[TOTALRADIXGROUPS * RADICES];
uint gRadixBlockSum[PREFIX_NUM_BLOCKS];
__device__ uint dRadixBlockSum[PREFIX_NUM_BLOCKS];

// Dynamically-sized shared memory, bound at kernel launch time.
extern __shared__ uint sRadixSum[];
90 ////////////////////////////////////////////////////////////////////////////////
91 //! Perform a radix sum on the list to be sorted. Each SM holds a set of
92 //! radix counters for each group of RADIXGROUPS thread in the GRF.
94 //! @param pData input data
95 //! @param elements total number of elements
96 //! @param elements_rounded_to_3072 total number of elements rounded up to the
97 //! nearest multiple of 3072
98 //! @param shift the shift (0 to 24) that we are using to obtain the correct radix digit
100 ////////////////////////////////////////////////////////////////////////////////
// ---------------------------------------------------------------------------
// RadixSum: histogram pass of the radix sort. Counts, in shared-memory
// counters (sRadixSum), how many keys of pData fall into each of the RADICES
// bins for the RADIX-bit digit selected by `shift` -- one counter set per
// radix group -- then writes the counters out to dRadixSum.
//
// NOTE(review): the fused original line numbers below are not contiguous
// (101 -> 103 -> 106 -> 109 ...), so this listing is missing lines: braces,
// the shared-counter initialisation statement, the key load, the bodies of
// the sixteen `tmod` stages and the loop-advance/sync statements are not
// visible here. Comments added below describe only what the surviving lines
// establish; everything else is flagged as an assumption.
// ---------------------------------------------------------------------------
101 __global__ void RadixSum(KeyValuePair *pData, uint elements, uint elements_rounded_to_3072, uint shift)
// Sweep the shared counter array in strides of the block size.
// NOTE(review): the loop body (presumably sRadixSum[pos] = 0) is missing
// from this listing -- TODO confirm.
103 uint pos = threadIdx.x;
106 while (pos < GRFELEMENTS)
109 pos += NUM_THREADS_PER_BLOCK;
113 // Source addresses computed so that each thread is reading from a block of
114 // consecutive addresses so there are no conflicts between threads
115 // They then loop over their combined region and the next batch works elsewhere.
116 // So threads 0 to 16 work on memory 0 to 320.
117 // First reading 0,1,2,3...15 then 16,17,18,19...31 and so on
118 // optimising parallel access to shared memory by a thread accessing 16*threadID
119 // The next radix group runs from 320 to 640 and the same applies in that region
// tmod = lane index within the 16-thread radix group; tpos = radix group index.
120 uint tmod = threadIdx.x % RADIXTHREADS;
121 uint tpos = threadIdx.x / RADIXTHREADS;
123 // Take the rounded element list size so that all threads have a certain size dataset to work with
124 // and no zero size datasets confusing the issue
125 // By using a multiple of 3072 we ensure that all threads have elements
126 // to work with until the last phase, at which point we individually test
127 uint element_fraction = elements_rounded_to_3072 / TOTALRADIXGROUPS;
130 // Note that it is possible for both pos and end to be past the end of the element set
131 // which will be caught later.
132 pos = (blockIdx.x * RADIXGROUPS + tpos) * element_fraction;
133 uint end = pos + element_fraction;
135 //printf("pos: %d\n", pos);
// NOTE(review): the main per-element loop header and the declaration/load of
// `key` (used below at the shift/mask) are missing from this listing; only
// the commented-out variants at 143-144 survive.
142 // Read first data element if we are in the set of elements
143 //if( pos < elements )
144 //key = pData[pos].key;
146 // Read first data element, both items at once as the memory will want to coalesce like that anyway
154 // Calculate position of radix counter to increment
155 // There are RADICES radices in each pass (256)
156 // and hence this many counters for bin grouping
157 // Multiply by RADIXGROUPS (4) to spread through memory
158 // and into 4 radix groups
159 uint p = ((key >> shift) & RADIXMASK) * RADIXGROUPS;
161 // Increment radix counters
162 // Each radix group has its own set of counters
163 // so we add the thread position [0-3], ie the group index.
164 // We slow down here and take at least 16 cycles to write to the summation boxes
165 // but other groups will only conflict with themselves and so can also be writing
166 // 16 cycles here at least avoids retries.
167 uint ppos = p + tpos;
169 // If we are past the last element we don't want to do anything
170 // We do have to check each time, however, to ensure that all
171 // threads sync on each sync here.
// Sixteen serialised stages, one per lane of the radix group, so lanes update
// the shared counters one at a time (per the comment at 164-166).
// NOTE(review): each stage's body (presumably an sRadixSum[ppos] increment,
// with a SYNCIT between stages) is missing from this listing -- TODO confirm.
172 if (tmod == 0 && pos < elements)
175 if (tmod == 1 && pos < elements)
178 if (tmod == 2 && pos < elements)
181 if (tmod == 3 && pos < elements)
184 if (tmod == 4 && pos < elements)
187 if (tmod == 5 && pos < elements)
190 if (tmod == 6 && pos < elements)
193 if (tmod == 7 && pos < elements)
196 if (tmod == 8 && pos < elements)
199 if (tmod == 9 && pos < elements)
202 if (tmod == 10 && pos < elements)
205 if (tmod == 11 && pos < elements)
208 if (tmod == 12 && pos < elements)
211 if (tmod == 13 && pos < elements)
214 if (tmod == 14 && pos < elements)
217 if (tmod == 15 && pos < elements)
// NOTE(review): the end-of-loop advance of `pos`, the loop's closing brace
// and a block-wide barrier before the write-out are missing here.
229 // Output radix sums into separate memory regions for each radix group
230 // So this memory then is laid out:
231 // 0...... 192..... 384 ................ 192*256
232 // ie all 256 bins for each radix group
235 // 0 4 8 12... - block idx * 4
236 // And in the block boxes we see the 4 radix groups for that block
237 // So 0-192 should contain bin 0 for each radix group, and so on
// Copy this block's counters from shared memory into its slice of dRadixSum,
// one radix row per (NUM_THREADS_PER_BLOCK / RADIXGROUPS) threads per sweep.
238 uint offset = blockIdx.x * RADIXGROUPS;
239 uint row = threadIdx.x / RADIXGROUPS;
240 uint column = threadIdx.x % RADIXGROUPS;
241 while (row < RADICES)
243 dRadixSum[offset + row * TOTALRADIXGROUPS + column] = sRadixSum[row * RADIXGROUPS + column];
244 row += NUM_THREADS_PER_BLOCK / RADIXGROUPS;
248 ////////////////////////////////////////////////////////////////////////////////
249 //! Performs first part of parallel prefix sum - individual sums of each radix
250 //! count. By the end of this we have prefix sums on a block level in dRadixSum
251 //! and totals for blocks in dRadixBlockSum.
252 ////////////////////////////////////////////////////////////////////////////////
// ---------------------------------------------------------------------------
// RadixPrefixSum: block-level prefix sums over the radix counters produced by
// RadixSum. On exit, dRadixSum holds prefix sums local to each block's slice
// and dRadixBlockSum holds each block's grand total (written at 358-360).
//
// NOTE(review): the fused original line numbers jump (279 -> 283,
// 291 -> 297, ...), so this listing is incomplete: braces, the loop headers
// of the read/sum/write-back sweeps, the declaration and initialisation of
// `sum` and `p`, and the SYNCIT barriers between phases are not visible.
// Added comments describe only what the surviving lines establish.
// ---------------------------------------------------------------------------
253 __global__ void RadixPrefixSum()
255 // Read radix groups in offset by one in the GRF so a zero can be inserted at the beginning
256 // and the final sum of all radix counts summed here is tacked onto the end for reading by
258 // Each block in this case is the full number of threads per SM (and hence the total number
259 // of radix groups), 192. We should then have the total set of offsets for an entire radix
260 // group by the end of this stage
261 // Device mem addressing
263 uint brow = blockIdx.x * (RADICES / PREFIX_NUM_BLOCKS);
264 uint drow = threadIdx.x / TOTALRADIXGROUPS; // In default parameterisation this is always 0
265 uint dcolumn = threadIdx.x % TOTALRADIXGROUPS; // And similarly this is always the same as threadIdx.x
266 uint dpos = (brow + drow) * TOTALRADIXGROUPS + dcolumn;
267 uint end = ((blockIdx.x + 1) * (RADICES / PREFIX_NUM_BLOCKS)) * TOTALRADIXGROUPS;
268 // Shared mem addressing
269 uint srow = threadIdx.x / (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
270 uint scolumn = threadIdx.x % (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
// Shared addresses include one spacer slot per thread's stretch (the "+ 1"),
// used below to stage the parallel prefix sum without conflicts.
271 uint spos = srow * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1) + scolumn;
273 // Read (RADICES / PREFIX_NUM_BLOCKS) radix counts into the GRF alongside each other
// NOTE(review): the loop header for this strided copy is missing here.
276 sRadixSum[spos] = dRadixSum[dpos];
277 spos += (PREFIX_NUM_THREADS_PER_BLOCK / (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK)) *
278 (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1);
279 dpos += (TOTALRADIXGROUPS / PREFIX_NUM_THREADS_PER_BLOCK) * TOTALRADIXGROUPS;
283 // Perform preliminary sum on each thread's stretch of data
284 // Each thread having a block of 16, with spacers between 0...16 18...33 and so on
285 int pos = threadIdx.x * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1);
286 end = pos + (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
// NOTE(review): the running-total loop header and the initialisation of
// `sum` (presumably uint sum = 0) are missing from this listing.
290 sum += sRadixSum[pos];
291 sRadixSum[pos] = sum;
297 // Calculate internal offsets by performing a more traditional parallel
298 // prefix sum of the topmost member of each thread's work data. Right now,
299 // these are stored between the work data for each thread, allowing us to
300 // eliminate GRF conflicts as well as hold the offsets needed to complete the sum
301 // In other words we have:
302 // 0....15 16 17....32 33 34....
303 // Where this first stage updates the intermediate values (so 16=15, 33=32 etc)
304 int m = (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1);
305 pos = threadIdx.x * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1) +
306 (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
307 sRadixSum[pos] = sRadixSum[pos - 1];
309 // This stage then performs a parallel prefix sum (ie use powers of 2 to propagate in log n stages)
310 // to update 17, 34 etc with the totals to that point (so 34 becomes [34] + [17]) and so on.
// NOTE(review): the loop body is incomplete; `p` is used below but its
// derivation (presumably pos - m) and the doubling of m are missing.
311 while (m < PREFIX_NUM_THREADS_PER_BLOCK * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1))
314 uint t = ((p > 0) ? sRadixSum[p] : 0);
324 // Add internal offsets to each thread's work data.
325 // So now we take 17 and add it to all values 18 to 33 so all offsets for that block
327 pos = threadIdx.x * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1);
328 end = pos + (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
// NOTE(review): here too `p` is used without a visible definition, and the
// loop header around the accumulation at 333 is missing.
330 sum = ((p > 0) ? sRadixSum[p] : 0);
333 sRadixSum[pos] += sum;
338 // Write summed data back out to global memory in the same way as we read it in
339 // We now have prefix sum values internal to groups
340 brow = blockIdx.x * (RADICES / PREFIX_NUM_BLOCKS);
341 drow = threadIdx.x / TOTALRADIXGROUPS;
342 dcolumn = threadIdx.x % TOTALRADIXGROUPS;
343 srow = threadIdx.x / (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
344 scolumn = threadIdx.x % (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK);
// Written back shifted one element along (dcolumn + 1): each slot receives
// the sum of everything before it, i.e. an exclusive prefix.
345 dpos = (brow + drow) * TOTALRADIXGROUPS + dcolumn + 1;
346 spos = srow * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1) + scolumn;
347 end = ((blockIdx.x + 1) * RADICES / PREFIX_NUM_BLOCKS) * TOTALRADIXGROUPS;
// NOTE(review): the write-back loop header is missing here.
350 dRadixSum[dpos] = sRadixSum[spos];
351 dpos += (TOTALRADIXGROUPS / PREFIX_NUM_THREADS_PER_BLOCK) * TOTALRADIXGROUPS;
352 spos += (PREFIX_NUM_THREADS_PER_BLOCK / (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK)) *
353 (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1);
356 // Write last element to summation
357 // Storing block sums in a separate array
358 if (threadIdx.x == 0) {
359 dRadixBlockSum[blockIdx.x] = sRadixSum[PREFIX_NUM_THREADS_PER_BLOCK * (PREFIX_BLOCKSIZE / PREFIX_NUM_THREADS_PER_BLOCK + 1) - 1];
360 dRadixSum[blockIdx.x * PREFIX_BLOCKSIZE] = 0;
365 ////////////////////////////////////////////////////////////////////////////////
366 //! Initially perform prefix sum of block totals to obtain final set of offsets.
367 //! Then make use of radix sums to perform a shuffling of the data into the
370 //! @param pSrc input data
371 //! @param pDst output data
372 //! @param elements total number of elements
373 //! @param shift the shift (0 to 24) that we are using to obtain the correct radix digit
375 ////////////////////////////////////////////////////////////////////////////////
// ---------------------------------------------------------------------------
// RadixAddOffsetsAndShuffle: final pass. Prefix-sums the per-block totals
// from dRadixBlockSum, adds them to the per-radix-group counts so sRadixSum
// holds absolute output offsets, then walks the input again and scatters
// pSrc into pDst using (and incrementing) those offsets.
//
// NOTE(review): this listing is incomplete (fused line numbers jump):
// braces, the declarations of `n`, `ppos` (in the first loop), `kvp`,
// `kvpf2`, `index`, the main per-element loop header, an unmatched #else at
// 447 (its #if is missing), the pDst stores in each tmod stage, and the
// barriers are all missing. Added comments cover only what is visible.
// ---------------------------------------------------------------------------
376 __global__ void RadixAddOffsetsAndShuffle(KeyValuePair* pSrc, KeyValuePair* pDst, uint elements, uint elements_rounded_to_3072, int shift)
378 // Read offsets from previous blocks
// Stage the PREFIX_NUM_BLOCKS block totals after the counter table in shared
// memory, shifted one slot so entry 0 is the identity (zero) offset.
379 if (threadIdx.x == 0)
380 sRadixSum[SHUFFLE_GRFOFFSET] = 0;
382 if (threadIdx.x < PREFIX_NUM_BLOCKS - 1)
383 sRadixSum[SHUFFLE_GRFOFFSET + threadIdx.x + 1] = dRadixBlockSum[threadIdx.x];
386 // Parallel prefix sum over block sums
387 int pos = threadIdx.x;
// NOTE(review): `n` and `ppos` are used below but their declaration
// (presumably int n = 1; ppos = pos - n) and updates are missing here.
389 while (n < PREFIX_NUM_BLOCKS)
392 uint t0 = ((pos < PREFIX_NUM_BLOCKS) && (ppos >= 0)) ? sRadixSum[SHUFFLE_GRFOFFSET + ppos] : 0;
394 if (pos < PREFIX_NUM_BLOCKS)
395 sRadixSum[SHUFFLE_GRFOFFSET + pos] += t0;
400 // Read radix count data and add appropriate block offset
401 // for each radix at the memory location for this thread
402 // (where the other threads in the block will be reading
403 // as well, hence the large stride).
404 // There is one counter box per radix group per radix
405 // per block (4*256*3)
406 // We use 64 threads to read the 4 radix groups set of radices
408 int row = threadIdx.x / RADIXGROUPS;
409 int column = threadIdx.x % RADIXGROUPS;
410 int spos = row * RADIXGROUPS + column;
411 int dpos = row * TOTALRADIXGROUPS + column + blockIdx.x * RADIXGROUPS;
412 while (spos < SHUFFLE_GRFOFFSET)
// Each counter becomes an absolute destination offset: block-local prefix
// sum plus the prefix-summed total of all preceding blocks.
414 sRadixSum[spos] = dRadixSum[dpos] + sRadixSum[SHUFFLE_GRFOFFSET + dpos / (TOTALRADIXGROUPS * RADICES / PREFIX_NUM_BLOCKS)];
415 spos += NUM_THREADS_PER_BLOCK;
416 dpos += (NUM_THREADS_PER_BLOCK / RADIXGROUPS) * TOTALRADIXGROUPS;
422 // Each of the subbins for a block should be filled via the counters, properly interleaved
423 // Then, as we now iterate over each data value, we increment the subbins (each thread in the
424 // radix group in turn to avoid miss writes due to conflicts) and set locations correctly.
// Same element partitioning as RadixSum: each radix group owns a contiguous
// stretch of element_fraction elements.
425 uint element_fraction = elements_rounded_to_3072 / TOTALRADIXGROUPS;
426 int tmod = threadIdx.x % RADIXTHREADS;
427 int tpos = threadIdx.x / RADIXTHREADS;
429 pos = (blockIdx.x * RADIXGROUPS + tpos) * element_fraction;
430 uint end = pos + element_fraction; //(blockIdx.x * RADIXGROUPS + tpos + 1) * element_fraction;
// NOTE(review): the main loop header and the `kvp` declaration/load are
// missing; the #else at 447 has no visible matching #if.
439 // Read first data element, both items at once as the memory will want to coalesce like that anyway
447 #else // casting to float2 to get it to combine loads
450 // Read first data element, both items at once as the memory will want to coalesce like that anyway
// Alias the pair as a single int2 so the two 32-bit members load as one
// 64-bit transaction. NOTE(review): the unpacking of kvpf2 into kvp and the
// declaration of kvpf2 are missing from this listing.
454 kvpf2 = ((int2*)pSrc)[pos];
455 // printf("kvp: %f %f kvpf2: %f %f\n", kvp.key, kvp.value, kvpf2.x, kvpf2.y);
467 // Calculate position of radix counter to increment
468 uint p = ((kvp.key >> shift) & RADIXMASK) * RADIXGROUPS;
470 // Move data, keeping counts updated.
471 // Increment radix counters, relying on hexadecathread
472 // warp to prevent this code from stepping all over itself.
473 uint ppos = p + tpos;
// Sixteen serialised stages, one per lane: each surviving line reserves the
// next output slot for this key by post-incrementing its offset counter.
// NOTE(review): the store of the pair to pDst[index] after each reservation
// (and the barriers between stages) are missing from this listing.
474 if (tmod == 0 && pos < elements)
476 index = sRadixSum[ppos]++;
480 if (tmod == 1 && pos < elements)
482 index = sRadixSum[ppos]++;
486 if (tmod == 2 && pos < elements)
488 index = sRadixSum[ppos]++;
492 if (tmod == 3 && pos < elements)
494 index = sRadixSum[ppos]++;
498 if (tmod == 4 && pos < elements)
500 index = sRadixSum[ppos]++;
504 if (tmod == 5 && pos < elements)
506 index = sRadixSum[ppos]++;
510 if (tmod == 6 && pos < elements)
512 index = sRadixSum[ppos]++;
516 if (tmod == 7 && pos < elements)
518 index = sRadixSum[ppos]++;
522 if (tmod == 8 && pos < elements)
524 index = sRadixSum[ppos]++;
528 if (tmod == 9 && pos < elements)
530 index = sRadixSum[ppos]++;
534 if (tmod == 10 && pos < elements)
536 index = sRadixSum[ppos]++;
540 if (tmod == 11 && pos < elements)
542 index = sRadixSum[ppos]++;
546 if (tmod == 12 && pos < elements)
548 index = sRadixSum[ppos]++;
552 if (tmod == 13 && pos < elements)
554 index = sRadixSum[ppos]++;
558 if (tmod == 14 && pos < elements)
560 index = sRadixSum[ppos]++;
564 if (tmod == 15 && pos < elements)
566 index = sRadixSum[ppos]++;
577 #endif // #ifndef _RADIXSORT_KERNEL_H_