5 #define PATH "..\\..\\AdlPrimitives\\Sort\\RadixSortAdvancedKernels"
6 #define KERNEL0 "StreamCountKernel"
7 #define KERNEL1 "SortAndScatterKernel1"
8 #define KERNEL2 "PrefixScanKernel"
10 template<DeviceType type>
11 class RadixSortAdvanced : public RadixSortBase
14 typedef Launcher::BufferInfo BufferInfo;
20 MAX_NUM_WORKGROUPS = 60,
23 struct Data : public RadixSort<type>::Data
25 Kernel* m_localCountKernel;
26 Kernel* m_scatterKernel;
29 Buffer<u32>* m_workBuffer0;
30 Buffer<SortData>* m_workBuffer1;
31 Buffer<int4>* m_constBuffer[32/4];
36 Data* allocate(const Device* deviceData, int maxSize, Option option = SORT_NORMAL);
39 void deallocate(void* data);
42 void execute(void* data, Buffer<SortData>& inout, int n, int sortBits);
45 template<DeviceType type>
46 typename RadixSortAdvanced<type>::Data* RadixSortAdvanced<type>::allocate(const Device* deviceData, int maxSize, Option option)
48 ADLASSERT( type == deviceData->m_type );
50 const char* src[] = { 0, 0, 0 };
52 Data* data = new Data;
53 data->m_option = option;
54 data->m_deviceData = deviceData;
56 data->m_localCountKernel = deviceData->getKernel( PATH, KERNEL0, 0, src[type] );
57 data->m_scatterKernel = deviceData->getKernel( PATH, KERNEL1, 0, src[type] );
58 data->m_scanKernel = deviceData->getKernel( PATH, KERNEL2, 0, src[type] );
60 data->m_workBuffer0 = new Buffer<u32>( deviceData, MAX_NUM_WORKGROUPS*16 );
61 data->m_workBuffer1 = new Buffer<SortData>( deviceData, maxSize );
62 for(int i=0; i<32/4; i++)
63 data->m_constBuffer[i] = new Buffer<int4>( deviceData, 1, BufferBase::BUFFER_CONST );
64 data->m_maxSize = maxSize;
69 template<DeviceType type>
70 void RadixSortAdvanced<type>::deallocate(void* rawData)
72 Data* data = (Data*)rawData;
74 delete data->m_workBuffer0;
75 delete data->m_workBuffer1;
76 for(int i=0; i<32/4; i++)
77 delete data->m_constBuffer[i];
82 template<DeviceType type>
83 void RadixSortAdvanced<type>::execute(void* rawData, Buffer<SortData>& inout, int n, int sortBits)
85 Data* data = (Data*)rawData;
87 ADLASSERT( sortBits == 32 );
89 ADLASSERT( NUM_PER_WI == 4 );
90 ADLASSERT( n%(WG_SIZE*NUM_PER_WI) == 0 );
91 ADLASSERT( MAX_NUM_WORKGROUPS < 128*8/16 );
93 Buffer<SortData>* src = &inout;
94 Buffer<SortData>* dst = data->m_workBuffer1;
96 const Device* deviceData = data->m_deviceData;
98 int nBlocks = n/(NUM_PER_WI*WG_SIZE);
99 const int nWorkGroupsToExecute = min2((int)MAX_NUM_WORKGROUPS, nBlocks);
100 int nBlocksPerGroup = (nBlocks+nWorkGroupsToExecute-1)/nWorkGroupsToExecute;
101 ADLASSERT( nWorkGroupsToExecute <= MAX_NUM_WORKGROUPS );
103 int4 constBuffer = make_int4(0, nBlocks, nWorkGroupsToExecute, nBlocksPerGroup);
107 for(int startBit=0; startBit<32; startBit+=4, iPass++)
109 constBuffer.x = startBit;
112 BufferInfo bInfo[] = { BufferInfo( src, true ), BufferInfo( data->m_workBuffer0 ) };
114 Launcher launcher( deviceData, data->m_localCountKernel );
115 launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
116 launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
117 launcher.launch1D( WG_SIZE* nWorkGroupsToExecute, WG_SIZE );
122 BufferInfo bInfo[] = { BufferInfo( data->m_workBuffer0 ) };
124 Launcher launcher( deviceData, data->m_scanKernel );
125 launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
126 launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
127 launcher.launch1D( WG_SIZE, WG_SIZE );
131 BufferInfo bInfo[] = { BufferInfo( data->m_workBuffer0, true ), BufferInfo( src ), BufferInfo( dst ) };
133 Launcher launcher( deviceData, data->m_scatterKernel );
134 launcher.setBuffers( bInfo, sizeof(bInfo)/sizeof(Launcher::BufferInfo) );
135 launcher.setConst( *data->m_constBuffer[iPass], constBuffer );
136 launcher.launch1D( WG_SIZE*nWorkGroupsToExecute, WG_SIZE );