//************************************ bs::framework - Copyright 2018 Marko Pintera **************************************//
//*********** Licensed under the MIT license. See LICENSE.md for full terms. This notice is not to be removed. ***********//
#include "Utility/BsGpuSort.h"
#include "RenderAPI/BsGpuBuffer.h"
#include "Math/BsRandom.h"
#include "Renderer/BsRendererUtility.h"

namespace bs { namespace ct
{
	static constexpr UINT32 BIT_COUNT = 32;
	static constexpr UINT32 RADIX_NUM_BITS = 4;
	static constexpr UINT32 NUM_DIGITS = 1 << RADIX_NUM_BITS;
	static constexpr UINT32 KEY_MASK = (NUM_DIGITS - 1);
	static constexpr UINT32 NUM_PASSES = BIT_COUNT / RADIX_NUM_BITS;

	static constexpr UINT32 NUM_THREADS = 128;
	static constexpr UINT32 KEYS_PER_LOOP = 8;
	static constexpr UINT32 TILE_SIZE = NUM_THREADS * KEYS_PER_LOOP;
	static constexpr UINT32 MAX_NUM_GROUPS = 64;
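	// Keys are consumed in tiles of TILE_SIZE (NUM_THREADS * KEYS_PER_LOOP = 1024) keys, and tiles
	// are distributed across at most MAX_NUM_GROUPS thread groups; keys left over after the last
	// full tile are handled separately by the final group (see GpuSortProperties below).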

	RadixSortParamsDef gRadixSortParamsDef;

	/** Contains various constants required during the GpuSort algorithm. */
	struct GpuSortProperties
	{
		GpuSortProperties(UINT32 count)
			: count(count)
		{ }

		const UINT32 count;
		const UINT32 numTiles = count / TILE_SIZE;
		const UINT32 numGroups = Math::clamp(numTiles, 1U, MAX_NUM_GROUPS);

		const UINT32 tilesPerGroup = numTiles / numGroups;
		const UINT32 extraTiles = numTiles % numGroups;
		const UINT32 extraKeys = count % TILE_SIZE;
	};
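
	// For example, with count = 10000: numTiles = 10000 / 1024 = 9, numGroups = 9,
	// tilesPerGroup = 1, extraTiles = 0 and extraKeys = 10000 % 1024 = 784.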

	/** Set up common defines required by all radix sort shaders. */
	void initCommonDefines(ShaderDefines& defines)
	{
		defines.set("RADIX_NUM_BITS", RADIX_NUM_BITS);
		defines.set("NUM_THREADS", NUM_THREADS);
		defines.set("KEYS_PER_LOOP", KEYS_PER_LOOP);
		defines.set("MAX_NUM_GROUPS", MAX_NUM_GROUPS);
	}

	void runSortTest();

	/**
	 * Creates a new GPU parameter block buffer according to the gRadixSortParamsDef definition and writes GpuSort
	 * properties into the buffer.
	 */
	SPtr<GpuParamBlockBuffer> createGpuSortParams(const GpuSortProperties& props)
	{
		SPtr<GpuParamBlockBuffer> buffer = gRadixSortParamsDef.createBuffer();

		gRadixSortParamsDef.gTilesPerGroup.set(buffer, props.tilesPerGroup);
		gRadixSortParamsDef.gNumGroups.set(buffer, props.numGroups);
		gRadixSortParamsDef.gNumExtraTiles.set(buffer, props.extraTiles);
		gRadixSortParamsDef.gNumExtraKeys.set(buffer, props.extraKeys);
		gRadixSortParamsDef.gBitOffset.set(buffer, 0);

		return buffer;
	}

	/**
	 * Checks whether the provided buffer can be used in a GPU sort operation. Returns a pointer to an error message if
	 * the check fails, or nullptr if it passes.
	 */
	const char* checkSortBuffer(GpuBuffer& buffer)
	{
		static constexpr const char* INVALID_GPU_WRITE_MSG =
			"All buffers provided to GpuSort must be created with GBU_LOADSTORE flags enabled.";
		static constexpr const char* INVALID_TYPE_MSG =
			"All buffers provided to GpuSort must be of GBT_STANDARD type.";
		static constexpr const char* INVALID_FORMAT_MSG =
			"All buffers provided to GpuSort must use a 32-bit unsigned integer format.";

		const GpuBufferProperties& bufferProps = buffer.getProperties();
		if ((bufferProps.getUsage() & GBU_LOADSTORE) != GBU_LOADSTORE)
			return INVALID_GPU_WRITE_MSG;

		if(bufferProps.getType() != GBT_STANDARD)
			return INVALID_TYPE_MSG;

		if(bufferProps.getFormat() != BF_32X1U)
			return INVALID_FORMAT_MSG;

		return nullptr;
	}

	/** Creates a helper buffer used for storing intermediate information during GpuSort::sort. */
	SPtr<GpuBuffer> createHelperBuffer()
	{
		GPU_BUFFER_DESC desc;
		desc.elementCount = MAX_NUM_GROUPS * NUM_DIGITS;
		desc.format = BF_32X1U;
		desc.usage = GBU_LOADSTORE;
		desc.type = GBT_STANDARD;

		return GpuBuffer::create(desc);
	}

	RadixSortClearMat::RadixSortClearMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutput", mOutputParam);
	}

	void RadixSortClearMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortClearMat::execute(const SPtr<GpuBuffer>& outputOffsets)
	{
		BS_RENMAT_PROFILE_BLOCK

		mOutputParam.set(outputOffsets);

		bind();
		RenderAPI::instance().dispatchCompute(1);
	}

	RadixSortCountMat::RadixSortCountMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputKeys", mInputKeysParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputCounts", mOutputCountsParam);
	}

	void RadixSortCountMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortCountMat::execute(UINT32 numGroups, const SPtr<GpuParamBlockBuffer>& params,
		const SPtr<GpuBuffer>& inputKeys, const SPtr<GpuBuffer>& outputOffsets)
	{
		BS_RENMAT_PROFILE_BLOCK

		mInputKeysParam.set(inputKeys);
		mOutputCountsParam.set(outputOffsets);

		mParams->setParamBlockBuffer("Params", params);

		bind();
		RenderAPI::instance().dispatchCompute(numGroups);
	}

	RadixSortPrefixScanMat::RadixSortPrefixScanMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputCounts", mInputCountsParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputOffsets", mOutputOffsetsParam);
	}

	void RadixSortPrefixScanMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortPrefixScanMat::execute(const SPtr<GpuParamBlockBuffer>& params, const SPtr<GpuBuffer>& inputCounts,
		const SPtr<GpuBuffer>& outputOffsets)
	{
		BS_RENMAT_PROFILE_BLOCK

		mInputCountsParam.set(inputCounts);
		mOutputOffsetsParam.set(outputOffsets);

		mParams->setParamBlockBuffer("Params", params);

		bind();
		RenderAPI::instance().dispatchCompute(1);
	}

	RadixSortReorderMat::RadixSortReorderMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputOffsets", mInputOffsetsBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputKeys", mInputKeysBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputValues", mInputValuesBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputKeys", mOutputKeysBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputValues", mOutputValuesBufferParam);
	}

	void RadixSortReorderMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortReorderMat::execute(UINT32 numGroups, const SPtr<GpuParamBlockBuffer>& params,
		const SPtr<GpuBuffer>& inputPrefix, const GpuSortBuffers& buffers, UINT32 inputBufferIdx)
	{
		BS_RENMAT_PROFILE_BLOCK

		const UINT32 outputBufferIdx = (inputBufferIdx + 1) % 2;

		mInputOffsetsBufferParam.set(inputPrefix);
		mInputKeysBufferParam.set(buffers.keys[inputBufferIdx]);
		mInputValuesBufferParam.set(buffers.values[inputBufferIdx]);
		mOutputKeysBufferParam.set(buffers.keys[outputBufferIdx]);
		mOutputValuesBufferParam.set(buffers.values[outputBufferIdx]);

		mParams->setParamBlockBuffer("Params", params);

		bind();
		RenderAPI::instance().dispatchCompute(numGroups);
	}

	GpuSort::GpuSort()
	{
		mHelperBuffers[0] = createHelperBuffer();
		mHelperBuffers[1] = createHelperBuffer();
	}

	UINT32 GpuSort::sort(const GpuSortBuffers& buffers, UINT32 numKeys, UINT32 keyMask)
	{
		// Nothing to do if no input or output key buffers
		if(buffers.keys[0] == nullptr || buffers.keys[1] == nullptr)
			return 0;

		// Check if all buffers have been created with required options
		const char* errorMsg = nullptr;
		for(UINT32 i = 0; i < 2; i++)
		{
			errorMsg = checkSortBuffer(*buffers.keys[i]);
			if(errorMsg) break;

			if(buffers.values[i])
			{
				errorMsg = checkSortBuffer(*buffers.values[i]);
				if(errorMsg) break;
			}
		}

		if(errorMsg)
		{
			LOGERR("GpuSort failed: " + String(errorMsg));
			return 0;
		}
		// Check if all buffers have the same size
		bool validSize = buffers.keys[0]->getSize() == buffers.keys[1]->getSize();
		if(buffers.values[0] && buffers.values[1])
		{
			validSize = validSize &&
				buffers.keys[0]->getSize() == buffers.values[0]->getSize() &&
				buffers.keys[0]->getSize() == buffers.values[1]->getSize();
		}

		if (!validSize)
		{
			LOGERR("GpuSort failed: All sort buffers must have the same size.");
			return 0;
		}

		const GpuSortProperties gpuSortProps(numKeys);
		SPtr<GpuParamBlockBuffer> params = createGpuSortParams(gpuSortProps);

		UINT32 bitOffset = 0;
		UINT32 inputBufferIdx = 0;
		for(UINT32 i = 0; i < NUM_PASSES; i++)
		{
			if(((KEY_MASK << bitOffset) & keyMask) != 0)
			{
				gRadixSortParamsDef.gBitOffset.set(params, bitOffset);

				RadixSortClearMat::get()->execute(mHelperBuffers[0]);
				RadixSortCountMat::get()->execute(gpuSortProps.numGroups, params, buffers.keys[inputBufferIdx], mHelperBuffers[0]);
				RadixSortPrefixScanMat::get()->execute(params, mHelperBuffers[0], mHelperBuffers[1]);
				RadixSortReorderMat::get()->execute(gpuSortProps.numGroups, params, mHelperBuffers[1], buffers, inputBufferIdx);

				inputBufferIdx = (inputBufferIdx + 1) % 2;
			}

			bitOffset += RADIX_NUM_BITS;
		}

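		// Passes whose digits are fully masked out by keyMask are skipped, so the sorted output
		// may end up in either of the two ping-pong buffers; report which one holds it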
		return inputBufferIdx;
	}

	GpuSortBuffers GpuSort::createSortBuffers(UINT32 numElements, bool values)
	{
		GpuSortBuffers output;

		GPU_BUFFER_DESC bufferDesc;
		bufferDesc.elementCount = numElements;
		bufferDesc.format = BF_32X1U;
		bufferDesc.type = GBT_STANDARD;
		bufferDesc.usage = GBU_LOADSTORE;

		output.keys[0] = GpuBuffer::create(bufferDesc);
		output.keys[1] = GpuBuffer::create(bufferDesc);

		if(values)
		{
			output.values[0] = GpuBuffer::create(bufferDesc);
			output.values[1] = GpuBuffer::create(bufferDesc);
		}

		return output;
	}
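
	// A typical use of the interface above might look like the following sketch (assumes a
	// GpuSort instance named "sorter" and a CPU-side Vector<UINT32> named "keys"):
	//
	//   GpuSortBuffers buffers = GpuSort::createSortBuffers((UINT32)keys.size(), false);
	//   buffers.keys[0]->writeData(0, (UINT32)keys.size() * sizeof(UINT32), keys.data(), BWT_DISCARD);
	//
	//   UINT32 resultIdx = sorter.sort(buffers, (UINT32)keys.size(), 0xFFFFFFFF);
	//   buffers.keys[resultIdx]->readData(0, (UINT32)keys.size() * sizeof(UINT32), keys.data());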

	// Note: This test isn't currently hooked up anywhere. It might be a good idea to set it up as a unit test, but that
	// would require exposing parts of GpuSort publicly, which doesn't seem worth doing just for a test. Instead, make
	// sure to run the test below whenever you modify any of the GpuSort code.
	void runSortTest()
	{
		// Generate test keys
		static constexpr UINT32 NUM_INPUT_KEYS = 10000;
		Vector<UINT32> inputKeys;
		inputKeys.reserve(NUM_INPUT_KEYS);

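		// Each key uses only 8 bits: bits 0-3 hold a bucket index that grows monotonically with i,
		// while bits 4-7 are random. A single stable pass over bits 4-7 (bitOffset = 4 below)
		// therefore leaves the full keys totally ordered, which the final is_sorted check relies on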
		Random random;
		for(UINT32 i = 0; i < NUM_INPUT_KEYS; i++)
			inputKeys.push_back(random.getRange(0, 15) << 4 | std::min(NUM_DIGITS - 1, (i / (NUM_INPUT_KEYS / 16))));

		const auto count = (UINT32)inputKeys.size();
		UINT32 bitOffset = 4;
		UINT32 bitMask = (1 << RADIX_NUM_BITS) - 1;

		// Prepare buffers
		const GpuSortProperties gpuSortProps(count);
		SPtr<GpuParamBlockBuffer> params = createGpuSortParams(gpuSortProps);

		gRadixSortParamsDef.gBitOffset.set(params, bitOffset);

		GpuSortBuffers sortBuffers = GpuSort::createSortBuffers(count);
		sortBuffers.keys[0]->writeData(0, sortBuffers.keys[0]->getSize(), inputKeys.data(), BWT_DISCARD);

		SPtr<GpuBuffer> helperBuffers[2];
		helperBuffers[0] = createHelperBuffer();
		helperBuffers[1] = createHelperBuffer();

		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		//////////////////////////////////////////// Count keys per group //////////////////////////////////////////////////
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

		// SERIAL:
		Vector<UINT32> counts(gpuSortProps.numGroups * NUM_DIGITS);
		for(UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			// Count keys per thread
			UINT32 localCounts[NUM_THREADS * NUM_DIGITS] = { 0 };

			UINT32 tileIdx;
			UINT32 numTiles;
			if(groupIdx < gpuSortProps.extraTiles)
			{
				numTiles = gpuSortProps.tilesPerGroup + 1;
				tileIdx = groupIdx * numTiles;
			}
			else
			{
				numTiles = gpuSortProps.tilesPerGroup;
				tileIdx = groupIdx * numTiles + gpuSortProps.extraTiles;
			}

			UINT32 keyBegin = tileIdx * TILE_SIZE;
			UINT32 keyEnd = keyBegin + numTiles * TILE_SIZE;

			while(keyBegin < keyEnd)
			{
				for(UINT32 threadIdx = 0; threadIdx < NUM_THREADS; threadIdx++)
				{
					UINT32 key = inputKeys[keyBegin + threadIdx];
					UINT32 digit = (key >> bitOffset) & bitMask;

					localCounts[threadIdx * NUM_DIGITS + digit] += 1;
				}

				keyBegin += NUM_THREADS;
			}

			if(groupIdx == (gpuSortProps.numGroups - 1))
			{
				keyBegin = keyEnd;
				keyEnd = keyBegin + gpuSortProps.extraKeys;

				while(keyBegin < keyEnd)
				{
					for (UINT32 threadIdx = 0; threadIdx < NUM_THREADS; threadIdx++)
					{
						if((keyBegin + threadIdx) < keyEnd)
						{
							UINT32 key = inputKeys[keyBegin + threadIdx];
							UINT32 digit = (key >> bitOffset) & bitMask;

							localCounts[threadIdx * NUM_DIGITS + digit] += 1;
						}
					}

					keyBegin += NUM_THREADS;
				}
			}

			// Sum up all key counts in a group
			static constexpr UINT32 NUM_REDUCE_THREADS = 64;
			static constexpr UINT32 NUM_REDUCE_THREADS_PER_DIGIT = NUM_REDUCE_THREADS / NUM_DIGITS;
			static constexpr UINT32 NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT = NUM_THREADS / NUM_REDUCE_THREADS_PER_DIGIT;
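			// With 16 digits this yields 4 reduce threads per digit, each serially summing 32 of
			// the 128 per-thread counters; the parallel reduction below then combines those four
			// partial sums per digit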

			UINT32 reduceCounters[NUM_REDUCE_THREADS] = { 0 };
			UINT32 reduceTotals[NUM_REDUCE_THREADS] = { 0 };
			for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
			{
				if(threadId < NUM_REDUCE_THREADS)
				{
					UINT32 digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
					UINT32 setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);

					UINT32 total = 0;
					for(UINT32 i = 0; i < NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT; i++)
					{
						UINT32 threadIdx = (setIdx * NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT + i) * NUM_DIGITS;
						total += localCounts[threadIdx + digitIdx];
					}

					reduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx] = total;
					reduceTotals[threadId] = total;
				}
			}

			// And do a parallel reduction on the results of the serial additions
			for (UINT32 i = 1; i < NUM_REDUCE_THREADS_PER_DIGIT; i <<= 1)
			{
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_REDUCE_THREADS)
					{
						UINT32 digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
						UINT32 setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);

						// Guard against reading past this digit's set of counters (only the
						// setIdx == 0 entries are consumed below, so skipping the tail is safe)
						if ((setIdx + i) < NUM_REDUCE_THREADS_PER_DIGIT)
							reduceTotals[threadId] += reduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx + i];
					}
				}

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_REDUCE_THREADS)
					{
						UINT32 digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
						UINT32 setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);

						reduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx] = reduceTotals[threadId];
					}
				}
			}

			for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
			{
				if(threadId < NUM_DIGITS)
					counts[groupIdx * NUM_DIGITS + threadId] = reduceCounters[threadId * NUM_REDUCE_THREADS_PER_DIGIT];
			}
		}

		// PARALLEL:
		RadixSortClearMat::get()->execute(helperBuffers[0]);
		RadixSortCountMat::get()->execute(gpuSortProps.numGroups, params, sortBuffers.keys[0], helperBuffers[0]);
		RenderAPI::instance().submitCommandBuffer(nullptr);

		// Compare with GPU count
		const UINT32 helperBufferLength = helperBuffers[0]->getProperties().getElementCount();
		Vector<UINT32> bufferCounts(helperBufferLength);
		helperBuffers[0]->readData(0, helperBufferLength * sizeof(UINT32), bufferCounts.data());

		for(UINT32 i = 0; i < (UINT32)counts.size(); i++)
			assert(bufferCounts[i] == counts[i]);

		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		/////////////////////////////////////////////// Calculate offsets //////////////////////////////////////////////////
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

		// SERIAL:
		// Copy the per-group digit counts into a scratch buffer for the scan
		Vector<UINT32> perDigitPrefixSum(NUM_DIGITS * MAX_NUM_GROUPS);
		for(UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			for (UINT32 j = 0; j < NUM_DIGITS; j++)
				perDigitPrefixSum[groupIdx * NUM_DIGITS + j] = counts[groupIdx * NUM_DIGITS + j];
		}

		// Prefix sum over per-digit counts over all groups
		//// Upsweep
		UINT32 offset = 1;
		for (UINT32 i = MAX_NUM_GROUPS >> 1; i > 0; i >>= 1)
		{
			for (UINT32 groupIdx = 0; groupIdx < MAX_NUM_GROUPS; groupIdx++)
			{
				if (groupIdx < i)
				{
					for (UINT32 j = 0; j < NUM_DIGITS; j++)
					{
						UINT32 idx0 = (offset * (2 * groupIdx + 1) - 1) * NUM_DIGITS + j;
						UINT32 idx1 = (offset * (2 * groupIdx + 2) - 1) * NUM_DIGITS + j;

						perDigitPrefixSum[idx1] += perDigitPrefixSum[idx0];
					}
				}
			}

			offset <<= 1;
		}

		//// Downsweep
		UINT32 totalsPrefixSum[NUM_DIGITS] = { 0 };
		for(UINT32 digitIdx = 0; digitIdx < NUM_DIGITS; digitIdx++)
		{
			UINT32 idx = (MAX_NUM_GROUPS - 1) * NUM_DIGITS + digitIdx;
			totalsPrefixSum[digitIdx] = perDigitPrefixSum[idx];
			perDigitPrefixSum[idx] = 0;
		}

		for (UINT32 i = 1; i < MAX_NUM_GROUPS; i <<= 1)
		{
			offset >>= 1;

			for (UINT32 groupIdx = 0; groupIdx < MAX_NUM_GROUPS; groupIdx++)
			{
				if (groupIdx < i)
				{
					for (UINT32 j = 0; j < NUM_DIGITS; j++)
					{
						UINT32 idx0 = (offset * (2 * groupIdx + 1) - 1) * NUM_DIGITS + j;
						UINT32 idx1 = (offset * (2 * groupIdx + 2) - 1) * NUM_DIGITS + j;

						UINT32 temp = perDigitPrefixSum[idx0];
						perDigitPrefixSum[idx0] = perDigitPrefixSum[idx1];
						perDigitPrefixSum[idx1] += temp;
					}
				}
			}
		}

		// Prefix sum over the total count
		for(UINT32 i = 1; i < NUM_DIGITS; i++)
			totalsPrefixSum[i] += totalsPrefixSum[i - 1];

		// Make it exclusive by shifting
		for(UINT32 i = NUM_DIGITS - 1; i > 0; i--)
			totalsPrefixSum[i] = totalsPrefixSum[i - 1];

		totalsPrefixSum[0] = 0;

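		// Final offset for (group, digit) = start of the digit's range in the global output,
		// plus the number of keys with the same digit in all preceding groups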
		Vector<UINT32> offsets(gpuSortProps.numGroups * NUM_DIGITS);
		for (UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			for (UINT32 i = 0; i < NUM_DIGITS; i++)
				offsets[groupIdx * NUM_DIGITS + i] = totalsPrefixSum[i] + perDigitPrefixSum[groupIdx * NUM_DIGITS + i];
		}

		// PARALLEL:
		RadixSortPrefixScanMat::get()->execute(params, helperBuffers[0], helperBuffers[1]);
		RenderAPI::instance().submitCommandBuffer(nullptr);

		// Compare with GPU offsets
		Vector<UINT32> bufferOffsets(helperBufferLength);
		helperBuffers[1]->readData(0, helperBufferLength * sizeof(UINT32), bufferOffsets.data());

		for(UINT32 i = 0; i < (UINT32)offsets.size(); i++)
			assert(bufferOffsets[i] == offsets[i]);

		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		/////////////////////////////////////////////////// Reorder ////////////////////////////////////////////////////////
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

		// SERIAL:
		// Reorder within each tile
		Vector<UINT32> sortedKeys(inputKeys.size());
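		// The s* arrays below stand in for the groupshared memory the reorder shader uses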
		UINT32 sGroupOffsets[NUM_DIGITS];
		UINT32 sLocalScratch[NUM_DIGITS * NUM_THREADS];
		UINT32 sTileTotals[NUM_DIGITS];
		UINT32 sCurrentTileTotal[NUM_DIGITS];

		for(UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			for(UINT32 i = 0; i < NUM_DIGITS; i++)
			{
				// Load offsets for this group to local memory
				sGroupOffsets[i] = offsets[groupIdx * NUM_DIGITS + i];

				// Clear tile totals
				sTileTotals[i] = 0;
			}

			// Handle the case when the number of tiles isn't exactly divisible by the number of groups,
			// in which case the first N groups handle the extra tiles
			UINT32 tileIdx;
			UINT32 numTiles;
			if(groupIdx < gpuSortProps.extraTiles)
			{
				numTiles = gpuSortProps.tilesPerGroup + 1;
				tileIdx = groupIdx * numTiles;
			}
			else
			{
				numTiles = gpuSortProps.tilesPerGroup;
				tileIdx = groupIdx * numTiles + gpuSortProps.extraTiles;
			}

			// We need to generate per-thread offsets (prefix sum) of where to store the keys
			// (This is equivalent to what was done in the count & prefix sum shaders, except that was done per-group)

			//// First, count all digits
			UINT32 keyBegin[NUM_THREADS];
			for(UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				keyBegin[threadId] = tileIdx * TILE_SIZE;

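			// Serial emulation of a per-digit Blelloch (upsweep/downsweep) scan over the per-thread
			// counts in sLocalScratch, producing an exclusive prefix sum for every digit (e.g. counts
			// 1, 2, 3, 4 scan to 0, 1, 3, 6) and this tile's per-digit totals in sCurrentTileTotal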
			auto prefixSum = [&sLocalScratch, &sCurrentTileTotal]()
			{
				// Upsweep to generate partial sums
				UINT32 offsets[NUM_THREADS];
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					offsets[threadId] = 1;

				for (UINT32 i = NUM_THREADS >> 1; i > 0; i >>= 1)
				{
					for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					{
						if (threadId < i)
						{
							// Note: If I run more than NUM_THREADS threads I wouldn't have to
							// iterate over all digits in a single thread
							// Note: Perhaps run part of this step serially for better performance
							for (UINT32 j = 0; j < NUM_DIGITS; j++)
							{
								UINT32 idx0 = (offsets[threadId] * (2 * threadId + 1) - 1) * NUM_DIGITS + j;
								UINT32 idx1 = (offsets[threadId] * (2 * threadId + 2) - 1) * NUM_DIGITS + j;

								// Note: Check and remove bank conflicts
								sLocalScratch[idx1] += sLocalScratch[idx0];
							}
						}

						offsets[threadId] <<= 1;
					}
				}

				// Set tree roots to zero (prepare for downsweep)
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_DIGITS)
					{
						UINT32 idx = (NUM_THREADS - 1) * NUM_DIGITS + threadId;
						sCurrentTileTotal[threadId] = sLocalScratch[idx];

						sLocalScratch[idx] = 0;
					}
				}

				// Downsweep to calculate the prefix sum from partial sums that were generated
				// during upsweep
				for (UINT32 i = 1; i < NUM_THREADS; i <<= 1)
				{
					for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					{
						offsets[threadId] >>= 1;

						if (threadId < i)
						{
							for (UINT32 j = 0; j < NUM_DIGITS; j++)
							{
								UINT32 idx0 = (offsets[threadId] * (2 * threadId + 1) - 1) * NUM_DIGITS + j;
								UINT32 idx1 = (offsets[threadId] * (2 * threadId + 2) - 1) * NUM_DIGITS + j;

								// Note: Check and resolve bank conflicts
								UINT32 temp = sLocalScratch[idx0];
								sLocalScratch[idx0] = sLocalScratch[idx1];
								sLocalScratch[idx1] += temp;
							}
						}
					}
				}
			};

			for(UINT32 tile = 0; tile < numTiles; tile++)
			{
				// Zero out local counter
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						sLocalScratch[i * NUM_THREADS + threadId] = 0;

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 idx = keyBegin[threadId] + threadId * KEYS_PER_LOOP + i;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						sLocalScratch[threadId * NUM_DIGITS + digit] += 1;
					}
				}

				prefixSum();

				// Actually re-order the keys
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					UINT32 localOffsets[NUM_DIGITS];
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						localOffsets[i] = 0;

					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 idx = keyBegin[threadId] + threadId * KEYS_PER_LOOP + i;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						UINT32 offset = sGroupOffsets[digit] + sTileTotals[digit] + sLocalScratch[threadId * NUM_DIGITS + digit] + localOffsets[digit];
						localOffsets[digit]++;

						// Note: First write to local memory then attempt to coalesce when writing to global?
						sortedKeys[offset] = key;
					}
				}

				// Update tile totals
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_DIGITS)
						sTileTotals[threadId] += sCurrentTileTotal[threadId];
				}

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					keyBegin[threadId] += TILE_SIZE;
			}

			if (groupIdx == (gpuSortProps.numGroups - 1) && gpuSortProps.extraKeys > 0)
			{
				// Zero out local counter
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						sLocalScratch[i * NUM_THREADS + threadId] = 0;

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 localIdx = threadId * KEYS_PER_LOOP + i;

						if (localIdx >= gpuSortProps.extraKeys)
							continue;

						UINT32 idx = keyBegin[threadId] + localIdx;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						sLocalScratch[threadId * NUM_DIGITS + digit] += 1;
					}
				}

				prefixSum();

				// Actually re-order the keys
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					UINT32 localOffsets[NUM_DIGITS];
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						localOffsets[i] = 0;

					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 localIdx = threadId * KEYS_PER_LOOP + i;

						if (localIdx >= gpuSortProps.extraKeys)
							continue;

						UINT32 idx = keyBegin[threadId] + localIdx;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						UINT32 offset = sGroupOffsets[digit] + sTileTotals[digit] + sLocalScratch[threadId * NUM_DIGITS + digit] + localOffsets[digit];
						localOffsets[digit]++;

						// Note: First write to local memory then attempt to coalesce when writing to global?
						sortedKeys[offset] = key;
					}
				}
			}
		}

		// PARALLEL:
		RadixSortReorderMat::get()->execute(gpuSortProps.numGroups, params, helperBuffers[1], sortBuffers, 0);
		RenderAPI::instance().submitCommandBuffer(nullptr);

		// Compare with GPU keys
		Vector<UINT32> bufferSortedKeys(count);
		sortBuffers.keys[1]->readData(0, count * sizeof(UINT32), bufferSortedKeys.data());

		for(UINT32 i = 0; i < count; i++)
			assert(bufferSortedKeys[i] == sortedKeys[i]);

		// Ensure everything is actually sorted
		assert(std::is_sorted(sortedKeys.begin(), sortedKeys.end()));
	}
}}
799 | |