//************************************ bs::framework - Copyright 2018 Marko Pintera **************************************//
//*********** Licensed under the MIT license. See LICENSE.md for full terms. This notice is not to be removed. ***********//
#include "Utility/BsGpuSort.h"
#include "RenderAPI/BsGpuBuffer.h"
#include "Math/BsRandom.h"
#include "Renderer/BsRendererUtility.h"

namespace bs { namespace ct
{
	static constexpr UINT32 BIT_COUNT = 32;
	static constexpr UINT32 RADIX_NUM_BITS = 4;
	static constexpr UINT32 NUM_DIGITS = 1 << RADIX_NUM_BITS;
	static constexpr UINT32 KEY_MASK = (NUM_DIGITS - 1);
	static constexpr UINT32 NUM_PASSES = BIT_COUNT / RADIX_NUM_BITS;

	static constexpr UINT32 NUM_THREADS = 128;
	static constexpr UINT32 KEYS_PER_LOOP = 8;
	static constexpr UINT32 TILE_SIZE = NUM_THREADS * KEYS_PER_LOOP;
	static constexpr UINT32 MAX_NUM_GROUPS = 64;
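
	// With these values each sort pass processes a 4-bit digit (16 possible digit values),
	// a full 32-bit sort takes 32 / 4 = 8 passes, and each tile covers 128 * 8 = 1024 keys.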

	RadixSortParamsDef gRadixSortParamsDef;

	/** Contains various constants required during the GpuSort algorithm. */
	struct GpuSortProperties
	{
		GpuSortProperties(UINT32 count)
			: count(count)
		{ }

		const UINT32 count;
		const UINT32 numTiles = count / TILE_SIZE;
		const UINT32 numGroups = Math::clamp(numTiles, 1U, MAX_NUM_GROUPS);

		const UINT32 tilesPerGroup = numTiles / numGroups;
		const UINT32 extraTiles = numTiles % numGroups;
		const UINT32 extraKeys = count % TILE_SIZE;
	};
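
	// For example, count = 100000 gives numTiles = 97, numGroups = 64, tilesPerGroup = 1,
	// extraTiles = 33 (the first 33 groups each process one extra tile) and extraKeys = 672
	// (100000 - 97 * 1024), handled separately by the last group.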

	/** Sets up common defines required by all radix sort shaders. */
	void initCommonDefines(ShaderDefines& defines)
	{
		defines.set("RADIX_NUM_BITS", RADIX_NUM_BITS);
		defines.set("NUM_THREADS", NUM_THREADS);
		defines.set("KEYS_PER_LOOP", KEYS_PER_LOOP);
		defines.set("MAX_NUM_GROUPS", MAX_NUM_GROUPS);
	}

	void runSortTest();

	/**
	 * Creates a new GPU parameter block buffer according to the gRadixSortParamsDef definition and writes GpuSort
	 * properties into the buffer.
	 */
	SPtr<GpuParamBlockBuffer> createGpuSortParams(const GpuSortProperties& props)
	{
		SPtr<GpuParamBlockBuffer> buffer = gRadixSortParamsDef.createBuffer();

		gRadixSortParamsDef.gTilesPerGroup.set(buffer, props.tilesPerGroup);
		gRadixSortParamsDef.gNumGroups.set(buffer, props.numGroups);
		gRadixSortParamsDef.gNumExtraTiles.set(buffer, props.extraTiles);
		gRadixSortParamsDef.gNumExtraKeys.set(buffer, props.extraKeys);
		gRadixSortParamsDef.gBitOffset.set(buffer, 0);

		return buffer;
	}

	/**
	 * Checks whether the provided buffer can be used in a GPU sort operation. Returns a pointer to an error message if
	 * the check failed, or nullptr if the check passed.
	 */
	const char* checkSortBuffer(GpuBuffer& buffer)
	{
		static constexpr const char* INVALID_GPU_WRITE_MSG =
			"All buffers provided to GpuSort must be created with the GBU_LOADSTORE flag enabled.";
		static constexpr const char* INVALID_TYPE_MSG =
			"All buffers provided to GpuSort must be of GBT_STANDARD type.";
		static constexpr const char* INVALID_FORMAT_MSG =
			"All buffers provided to GpuSort must use a 32-bit unsigned integer format.";

		const GpuBufferProperties& bufferProps = buffer.getProperties();
		if ((bufferProps.getUsage() & GBU_LOADSTORE) != GBU_LOADSTORE)
			return INVALID_GPU_WRITE_MSG;

		if(bufferProps.getType() != GBT_STANDARD)
			return INVALID_TYPE_MSG;

		if(bufferProps.getFormat() != BF_32X1U)
			return INVALID_FORMAT_MSG;

		return nullptr;
	}

	/**
	 * Creates a helper buffer used for storing intermediate data (one entry per group, per digit) during
	 * GpuSort::sort.
	 */
	SPtr<GpuBuffer> createHelperBuffer()
	{
		GPU_BUFFER_DESC desc;
		desc.elementCount = MAX_NUM_GROUPS * NUM_DIGITS;
		desc.format = BF_32X1U;
		desc.usage = GBU_LOADSTORE;
		desc.type = GBT_STANDARD;

		return GpuBuffer::create(desc);
	}

	RadixSortClearMat::RadixSortClearMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutput", mOutputParam);
	}

	void RadixSortClearMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortClearMat::execute(const SPtr<GpuBuffer>& outputOffsets)
	{
		BS_RENMAT_PROFILE_BLOCK

		mOutputParam.set(outputOffsets);

		bind();
		RenderAPI::instance().dispatchCompute(1);
	}

	RadixSortCountMat::RadixSortCountMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputKeys", mInputKeysParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputCounts", mOutputCountsParam);
	}

	void RadixSortCountMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortCountMat::execute(UINT32 numGroups, const SPtr<GpuParamBlockBuffer>& params,
		const SPtr<GpuBuffer>& inputKeys, const SPtr<GpuBuffer>& outputOffsets)
	{
		BS_RENMAT_PROFILE_BLOCK

		mInputKeysParam.set(inputKeys);
		mOutputCountsParam.set(outputOffsets);

		mParams->setParamBlockBuffer("Params", params);

		bind();
		RenderAPI::instance().dispatchCompute(numGroups);
	}

	RadixSortPrefixScanMat::RadixSortPrefixScanMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputCounts", mInputCountsParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputOffsets", mOutputOffsetsParam);
	}

	void RadixSortPrefixScanMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortPrefixScanMat::execute(const SPtr<GpuParamBlockBuffer>& params, const SPtr<GpuBuffer>& inputCounts,
		const SPtr<GpuBuffer>& outputOffsets)
	{
		BS_RENMAT_PROFILE_BLOCK

		mInputCountsParam.set(inputCounts);
		mOutputOffsetsParam.set(outputOffsets);

		mParams->setParamBlockBuffer("Params", params);

		bind();
		RenderAPI::instance().dispatchCompute(1);
	}

	RadixSortReorderMat::RadixSortReorderMat()
	{
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputOffsets", mInputOffsetsBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputKeys", mInputKeysBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gInputValues", mInputValuesBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputKeys", mOutputKeysBufferParam);
		mParams->getBufferParam(GPT_COMPUTE_PROGRAM, "gOutputValues", mOutputValuesBufferParam);
	}

	void RadixSortReorderMat::_initDefines(ShaderDefines& defines)
	{
		initCommonDefines(defines);
	}

	void RadixSortReorderMat::execute(UINT32 numGroups, const SPtr<GpuParamBlockBuffer>& params,
		const SPtr<GpuBuffer>& inputPrefix, const GpuSortBuffers& buffers, UINT32 inputBufferIdx)
	{
		BS_RENMAT_PROFILE_BLOCK

		const UINT32 outputBufferIdx = (inputBufferIdx + 1) % 2;

		mInputOffsetsBufferParam.set(inputPrefix);
		mInputKeysBufferParam.set(buffers.keys[inputBufferIdx]);
		mInputValuesBufferParam.set(buffers.values[inputBufferIdx]);
		mOutputKeysBufferParam.set(buffers.keys[outputBufferIdx]);
		mOutputValuesBufferParam.set(buffers.values[outputBufferIdx]);

		mParams->setParamBlockBuffer("Params", params);

		bind();
		RenderAPI::instance().dispatchCompute(numGroups);
	}

	GpuSort::GpuSort()
	{
		mHelperBuffers[0] = createHelperBuffer();
		mHelperBuffers[1] = createHelperBuffer();
	}

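	// A minimal usage sketch (hypothetical caller code; the full 32-bit key mask and the
	// numKeys/keys/sortedKeys variables are assumptions). sort() returns the index of the
	// buffer holding the sorted result, since data ping-pongs between the two buffers:
	//
	//   GpuSortBuffers buffers = GpuSort::createSortBuffers(numKeys);
	//   buffers.keys[0]->writeData(0, numKeys * sizeof(UINT32), keys, BWT_DISCARD);
	//
	//   GpuSort sorter;
	//   UINT32 outputIdx = sorter.sort(buffers, numKeys, 0xFFFFFFFF);
	//   buffers.keys[outputIdx]->readData(0, numKeys * sizeof(UINT32), sortedKeys);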
	UINT32 GpuSort::sort(const GpuSortBuffers& buffers, UINT32 numKeys, UINT32 keyMask)
	{
		// Nothing to do if no input or output key buffers
		if(buffers.keys[0] == nullptr || buffers.keys[1] == nullptr)
			return 0;

		// Check if all buffers have been created with required options
		const char* errorMsg = nullptr;
		for(UINT32 i = 0; i < 2; i++)
		{
			errorMsg = checkSortBuffer(*buffers.keys[i]);
			if(errorMsg) break;

			if(buffers.values[i])
			{
				errorMsg = checkSortBuffer(*buffers.values[i]);
				if(errorMsg) break;
			}
		}

		if(errorMsg)
		{
			LOGERR("GpuSort failed: " + String(errorMsg));
			return 0;
		}

		// Check if all buffers have the same size
		bool validSize = buffers.keys[0]->getSize() == buffers.keys[1]->getSize();
		if(buffers.values[0] && buffers.values[1])
		{
			validSize = validSize &&
				buffers.keys[0]->getSize() == buffers.values[0]->getSize() &&
				buffers.keys[0]->getSize() == buffers.values[1]->getSize();
		}

		if (!validSize)
		{
			LOGERR("GpuSort failed: All sort buffers must have the same size.");
			return 0;
		}

		const GpuSortProperties gpuSortProps(numKeys);
		SPtr<GpuParamBlockBuffer> params = createGpuSortParams(gpuSortProps);

		UINT32 bitOffset = 0;
		UINT32 inputBufferIdx = 0;
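
		// Each pass sorts one RADIX_NUM_BITS-wide digit, starting from the lowest bits. Passes
		// whose digit doesn't overlap the caller-provided keyMask are skipped entirely, so e.g.
		// sorting 16-bit keys with keyMask = 0xFFFF runs only half the passes.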
		for(UINT32 i = 0; i < NUM_PASSES; i++)
		{
			if(((KEY_MASK << bitOffset) & keyMask) != 0)
			{
				gRadixSortParamsDef.gBitOffset.set(params, bitOffset);

				RadixSortClearMat::get()->execute(mHelperBuffers[0]);
				RadixSortCountMat::get()->execute(gpuSortProps.numGroups, params, buffers.keys[inputBufferIdx], mHelperBuffers[0]);
				RadixSortPrefixScanMat::get()->execute(params, mHelperBuffers[0], mHelperBuffers[1]);
				RadixSortReorderMat::get()->execute(gpuSortProps.numGroups, params, mHelperBuffers[1], buffers, inputBufferIdx);

				inputBufferIdx = (inputBufferIdx + 1) % 2;
			}

			bitOffset += RADIX_NUM_BITS;
		}

		return inputBufferIdx;
	}

	GpuSortBuffers GpuSort::createSortBuffers(UINT32 numElements, bool values)
	{
		GpuSortBuffers output;

		GPU_BUFFER_DESC bufferDesc;
		bufferDesc.elementCount = numElements;
		bufferDesc.format = BF_32X1U;
		bufferDesc.type = GBT_STANDARD;
		bufferDesc.usage = GBU_LOADSTORE;

		output.keys[0] = GpuBuffer::create(bufferDesc);
		output.keys[1] = GpuBuffer::create(bufferDesc);

		if(values)
		{
			output.values[0] = GpuBuffer::create(bufferDesc);
			output.values[1] = GpuBuffer::create(bufferDesc);
		}

		return output;
	}

	// Note: This test isn't currently hooked up anywhere. It might be a good idea to set it up as a unit test, but that
	// would require exposing parts of GpuSort publicly, which doesn't seem worth doing just for a test. Instead, make
	// sure to run the test below whenever you modify any of the GpuSort code.
	void runSortTest()
	{
		// Generate test keys
		static constexpr UINT32 NUM_INPUT_KEYS = 10000;
		Vector<UINT32> inputKeys;
		inputKeys.reserve(NUM_INPUT_KEYS);

		Random random;
		for(UINT32 i = 0; i < NUM_INPUT_KEYS; i++)
			inputKeys.push_back(random.getRange(0, 15) << 4 | std::min(NUM_DIGITS - 1, (i / (NUM_INPUT_KEYS / 16))));
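
		// The low 4 bits of each key ascend monotonically with i while the high 4 bits are
		// random. A single sort pass over the high digit (bitOffset = 4) therefore fully
		// sorts the keys: radix sort is stable, so the already-ordered low digits are
		// preserved within each high-digit bucket. The final is_sorted check relies on this.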

		const auto count = (UINT32)inputKeys.size();
		UINT32 bitOffset = 4;
		UINT32 bitMask = (1 << RADIX_NUM_BITS) - 1;

		// Prepare buffers
		const GpuSortProperties gpuSortProps(count);
		SPtr<GpuParamBlockBuffer> params = createGpuSortParams(gpuSortProps);

		gRadixSortParamsDef.gBitOffset.set(params, bitOffset);

		GpuSortBuffers sortBuffers = GpuSort::createSortBuffers(count);
		sortBuffers.keys[0]->writeData(0, sortBuffers.keys[0]->getSize(), inputKeys.data(), BWT_DISCARD);

		SPtr<GpuBuffer> helperBuffers[2];
		helperBuffers[0] = createHelperBuffer();
		helperBuffers[1] = createHelperBuffer();

		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		//////////////////////////////////////////// Count keys per group //////////////////////////////////////////////////
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

		// SERIAL:
		Vector<UINT32> counts(gpuSortProps.numGroups * NUM_DIGITS);
		for(UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			// Count keys per thread
			UINT32 localCounts[NUM_THREADS * NUM_DIGITS] = { 0 };

			UINT32 tileIdx;
			UINT32 numTiles;
			if(groupIdx < gpuSortProps.extraTiles)
			{
				numTiles = gpuSortProps.tilesPerGroup + 1;
				tileIdx = groupIdx * numTiles;
			}
			else
			{
				numTiles = gpuSortProps.tilesPerGroup;
				tileIdx = groupIdx * numTiles + gpuSortProps.extraTiles;
			}

			UINT32 keyBegin = tileIdx * TILE_SIZE;
			UINT32 keyEnd = keyBegin + numTiles * TILE_SIZE;

			while(keyBegin < keyEnd)
			{
				for(UINT32 threadIdx = 0; threadIdx < NUM_THREADS; threadIdx++)
				{
					UINT32 key = inputKeys[keyBegin + threadIdx];
					UINT32 digit = (key >> bitOffset) & bitMask;

					localCounts[threadIdx * NUM_DIGITS + digit] += 1;
				}

				keyBegin += NUM_THREADS;
			}

			if(groupIdx == (gpuSortProps.numGroups - 1))
			{
				keyBegin = keyEnd;
				keyEnd = keyBegin + gpuSortProps.extraKeys;

				while(keyBegin < keyEnd)
				{
					for (UINT32 threadIdx = 0; threadIdx < NUM_THREADS; threadIdx++)
					{
						if((keyBegin + threadIdx) < keyEnd)
						{
							UINT32 key = inputKeys[keyBegin + threadIdx];
							UINT32 digit = (key >> bitOffset) & bitMask;

							localCounts[threadIdx * NUM_DIGITS + digit] += 1;
						}
					}

					keyBegin += NUM_THREADS;
				}
			}

			// Sum up all key counts in a group
			static constexpr UINT32 NUM_REDUCE_THREADS = 64;
			static constexpr UINT32 NUM_REDUCE_THREADS_PER_DIGIT = NUM_REDUCE_THREADS / NUM_DIGITS;
			static constexpr UINT32 NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT = NUM_THREADS / NUM_REDUCE_THREADS_PER_DIGIT;
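
			// i.e. 64 reduce threads, 64 / 16 = 4 threads per digit, each serially summing
			// 128 / 4 = 32 per-thread counters before the parallel reduction below.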

			UINT32 reduceCounters[NUM_REDUCE_THREADS] = { 0 };
			UINT32 reduceTotals[NUM_REDUCE_THREADS] = { 0 };
			for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
			{
				if(threadId < NUM_REDUCE_THREADS)
				{
					UINT32 digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
					UINT32 setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);

					UINT32 total = 0;
					for(UINT32 i = 0; i < NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT; i++)
					{
						UINT32 threadIdx = (setIdx * NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT + i) * NUM_DIGITS;
						total += localCounts[threadIdx + digitIdx];
					}

					reduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx] = total;
					reduceTotals[threadId] = total;
				}
			}

			// And do a parallel reduction on the results of the serial additions
			for (UINT32 i = 1; i < NUM_REDUCE_THREADS_PER_DIGIT; i <<= 1)
			{
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_REDUCE_THREADS)
					{
						UINT32 digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
						UINT32 setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);

						// Guard the read so the last sets of a digit don't read past their own range
						// (only the setIdx == 0 results are consumed below, so skipping is safe)
						if ((setIdx + i) < NUM_REDUCE_THREADS_PER_DIGIT)
							reduceTotals[threadId] += reduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx + i];
					}
				}

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_REDUCE_THREADS)
					{
						UINT32 digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
						UINT32 setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);

						reduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx] = reduceTotals[threadId];
					}
				}
			}

			for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
			{
				if(threadId < NUM_DIGITS)
					counts[groupIdx * NUM_DIGITS + threadId] = reduceCounters[threadId * NUM_REDUCE_THREADS_PER_DIGIT];
			}
		}

		// PARALLEL:
		RadixSortClearMat::get()->execute(helperBuffers[0]);
		RadixSortCountMat::get()->execute(gpuSortProps.numGroups, params, sortBuffers.keys[0], helperBuffers[0]);
		RenderAPI::instance().submitCommandBuffer(nullptr);

		// Compare with GPU count
		const UINT32 helperBufferLength = helperBuffers[0]->getProperties().getElementCount();
		Vector<UINT32> bufferCounts(helperBufferLength);
		helperBuffers[0]->readData(0, helperBufferLength * sizeof(UINT32), bufferCounts.data());

		for(UINT32 i = 0; i < (UINT32)counts.size(); i++)
			assert(bufferCounts[i] == counts[i]);

		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		/////////////////////////////////////////////// Calculate offsets //////////////////////////////////////////////////
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
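
		// The scan below is a Blelloch-style exclusive scan: an upsweep builds partial sums in
		// place, then a downsweep distributes them. E.g. per-group counts [3, 1, 0, 2] for one
		// digit become exclusive offsets [0, 3, 4, 4].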

		// SERIAL:
		// Copy the per-group, per-digit counts into a working buffer for the scan
		Vector<UINT32> perDigitPrefixSum(NUM_DIGITS * MAX_NUM_GROUPS);
		for(UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			for (UINT32 j = 0; j < NUM_DIGITS; j++)
				perDigitPrefixSum[groupIdx * NUM_DIGITS + j] = counts[groupIdx * NUM_DIGITS + j];
		}

		// Prefix sum over per-digit counts over all groups
		//// Upsweep
		UINT32 offset = 1;
		for (UINT32 i = MAX_NUM_GROUPS >> 1; i > 0; i >>= 1)
		{
			for (UINT32 groupIdx = 0; groupIdx < MAX_NUM_GROUPS; groupIdx++)
			{
				if (groupIdx < i)
				{
					for (UINT32 j = 0; j < NUM_DIGITS; j++)
					{
						UINT32 idx0 = (offset * (2 * groupIdx + 1) - 1) * NUM_DIGITS + j;
						UINT32 idx1 = (offset * (2 * groupIdx + 2) - 1) * NUM_DIGITS + j;

						perDigitPrefixSum[idx1] += perDigitPrefixSum[idx0];
					}
				}
			}

			offset <<= 1;
		}

		//// Downsweep
		UINT32 totalsPrefixSum[NUM_DIGITS] = { 0 };
		for(UINT32 groupIdx = 0; groupIdx < NUM_DIGITS; groupIdx++)
		{
			if (groupIdx < NUM_DIGITS)
			{
				UINT32 idx = (MAX_NUM_GROUPS - 1) * NUM_DIGITS + groupIdx;
				totalsPrefixSum[groupIdx] = perDigitPrefixSum[idx];
				perDigitPrefixSum[idx] = 0;
			}
		}

		for (UINT32 i = 1; i < MAX_NUM_GROUPS; i <<= 1)
		{
			offset >>= 1;

			for (UINT32 groupIdx = 0; groupIdx < MAX_NUM_GROUPS; groupIdx++)
			{
				if (groupIdx < i)
				{
					for (UINT32 j = 0; j < NUM_DIGITS; j++)
					{
						UINT32 idx0 = (offset * (2 * groupIdx + 1) - 1) * NUM_DIGITS + j;
						UINT32 idx1 = (offset * (2 * groupIdx + 2) - 1) * NUM_DIGITS + j;

						UINT32 temp = perDigitPrefixSum[idx0];
						perDigitPrefixSum[idx0] = perDigitPrefixSum[idx1];
						perDigitPrefixSum[idx1] += temp;
					}
				}
			}
		}

		// Prefix sum over the total count
		for(UINT32 i = 1; i < NUM_DIGITS; i++)
			totalsPrefixSum[i] += totalsPrefixSum[i - 1];

		// Make it exclusive by shifting
		for(UINT32 i = NUM_DIGITS - 1; i > 0; i--)
			totalsPrefixSum[i] = totalsPrefixSum[i - 1];

		totalsPrefixSum[0] = 0;
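
		// E.g. per-digit totals [4, 2, 3, ...] become inclusive sums [4, 6, 9, ...] above,
		// and the shift then turns them into exclusive sums [0, 4, 6, ...].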

		Vector<UINT32> offsets(gpuSortProps.numGroups * NUM_DIGITS);
		for (UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			for (UINT32 i = 0; i < NUM_DIGITS; i++)
				offsets[groupIdx * NUM_DIGITS + i] = totalsPrefixSum[i] + perDigitPrefixSum[groupIdx * NUM_DIGITS + i];
		}

		// PARALLEL:
		RadixSortPrefixScanMat::get()->execute(params, helperBuffers[0], helperBuffers[1]);
		RenderAPI::instance().submitCommandBuffer(nullptr);

		// Compare with GPU offsets
		Vector<UINT32> bufferOffsets(helperBufferLength);
		helperBuffers[1]->readData(0, helperBufferLength * sizeof(UINT32), bufferOffsets.data());

		for(UINT32 i = 0; i < (UINT32)offsets.size(); i++)
			assert(bufferOffsets[i] == offsets[i]);

		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
		/////////////////////////////////////////////////// Reorder ////////////////////////////////////////////////////////
		////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

		// SERIAL:
		// Reorder within each tile
		Vector<UINT32> sortedKeys(inputKeys.size());
		UINT32 sGroupOffsets[NUM_DIGITS];
		UINT32 sLocalScratch[NUM_DIGITS * NUM_THREADS];
		UINT32 sTileTotals[NUM_DIGITS];
		UINT32 sCurrentTileTotal[NUM_DIGITS];

		for(UINT32 groupIdx = 0; groupIdx < gpuSortProps.numGroups; groupIdx++)
		{
			for(UINT32 i = 0; i < NUM_DIGITS; i++)
			{
				// Load offsets for this group to local memory
				sGroupOffsets[i] = offsets[groupIdx * NUM_DIGITS + i];

				// Clear tile totals
				sTileTotals[i] = 0;
			}

			// Handle the case when the number of tiles isn't exactly divisible by the number of groups,
			// in which case the first 'extraTiles' groups each process one extra tile
			UINT32 tileIdx;
			UINT32 numTiles;
			if(groupIdx < gpuSortProps.extraTiles)
			{
				numTiles = gpuSortProps.tilesPerGroup + 1;
				tileIdx = groupIdx * numTiles;
			}
			else
			{
				numTiles = gpuSortProps.tilesPerGroup;
				tileIdx = groupIdx * numTiles + gpuSortProps.extraTiles;
			}

			// We need to generate per-thread offsets (a prefix sum) determining where to store the keys
			// (This is equivalent to what was done in the count & prefix sum shaders, except that was done per-group)
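
			// The final scatter position of a key (computed further below) is the sum of:
			//   sGroupOffsets[digit] - global start of this group's run of keys with this digit
			//   sTileTotals[digit]   - keys with this digit written by earlier tiles of this group
			//   sLocalScratch[...]   - keys with this digit owned by earlier threads in this tile
			//   localOffsets[digit]  - keys with this digit already written by this thread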

			//// First, count all digits
			UINT32 keyBegin[NUM_THREADS];
			for(UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				keyBegin[threadId] = tileIdx * TILE_SIZE;

			auto prefixSum = [&sLocalScratch, &sCurrentTileTotal]()
			{
				// Upsweep to generate partial sums
				UINT32 offsets[NUM_THREADS];
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					offsets[threadId] = 1;

				for (UINT32 i = NUM_THREADS >> 1; i > 0; i >>= 1)
				{
					for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					{
						if (threadId < i)
						{
							// Note: If I run more than NUM_THREADS threads I wouldn't have to
							// iterate over all digits in a single thread
							// Note: Perhaps run part of this step serially for better performance
							for (UINT32 j = 0; j < NUM_DIGITS; j++)
							{
								UINT32 idx0 = (offsets[threadId] * (2 * threadId + 1) - 1) * NUM_DIGITS + j;
								UINT32 idx1 = (offsets[threadId] * (2 * threadId + 2) - 1) * NUM_DIGITS + j;

								// Note: Check and remove bank conflicts
								sLocalScratch[idx1] += sLocalScratch[idx0];
							}
						}

						offsets[threadId] <<= 1;
					}
				}

				// Set tree roots to zero (prepare for downsweep)
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_DIGITS)
					{
						UINT32 idx = (NUM_THREADS - 1) * NUM_DIGITS + threadId;
						sCurrentTileTotal[threadId] = sLocalScratch[idx];

						sLocalScratch[idx] = 0;
					}
				}

				// Downsweep to calculate the prefix sum from partial sums that were generated
				// during upsweep
				for (UINT32 i = 1; i < NUM_THREADS; i <<= 1)
				{
					for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					{
						offsets[threadId] >>= 1;

						if (threadId < i)
						{
							for (UINT32 j = 0; j < NUM_DIGITS; j++)
							{
								UINT32 idx0 = (offsets[threadId] * (2 * threadId + 1) - 1) * NUM_DIGITS + j;
								UINT32 idx1 = (offsets[threadId] * (2 * threadId + 2) - 1) * NUM_DIGITS + j;

								// Note: Check and resolve bank conflicts
								UINT32 temp = sLocalScratch[idx0];
								sLocalScratch[idx0] = sLocalScratch[idx1];
								sLocalScratch[idx1] += temp;
							}
						}
					}
				}
			};
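
			// After prefixSum() runs, sLocalScratch[t * NUM_DIGITS + d] holds the number of keys
			// with digit d owned by threads preceding thread t in this tile (a per-digit exclusive
			// scan over threads), while sCurrentTileTotal[d] holds this tile's total for digit d.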

			for(UINT32 tile = 0; tile < numTiles; tile++)
			{
				// Zero out local counter
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						sLocalScratch[i * NUM_THREADS + threadId] = 0;

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 idx = keyBegin[threadId] + threadId * KEYS_PER_LOOP + i;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						sLocalScratch[threadId * NUM_DIGITS + digit] += 1;
					}
				}

				prefixSum();

				// Actually re-order the keys
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					UINT32 localOffsets[NUM_DIGITS];
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						localOffsets[i] = 0;

					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 idx = keyBegin[threadId] + threadId * KEYS_PER_LOOP + i;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						UINT32 offset = sGroupOffsets[digit] + sTileTotals[digit] +
							sLocalScratch[threadId * NUM_DIGITS + digit] + localOffsets[digit];
						localOffsets[digit]++;

						// Note: First write to local memory then attempt to coalesce when writing to global?
						sortedKeys[offset] = key;
					}
				}

				// Update tile totals
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					if (threadId < NUM_DIGITS)
						sTileTotals[threadId] += sCurrentTileTotal[threadId];
				}

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					keyBegin[threadId] += TILE_SIZE;
			}

			if (groupIdx == (gpuSortProps.numGroups - 1) && gpuSortProps.extraKeys > 0)
			{
				// Zero out local counter
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						sLocalScratch[i * NUM_THREADS + threadId] = 0;

				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 localIdx = threadId * KEYS_PER_LOOP + i;

						if (localIdx >= gpuSortProps.extraKeys)
							continue;

						UINT32 idx = keyBegin[threadId] + localIdx;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						sLocalScratch[threadId * NUM_DIGITS + digit] += 1;
					}
				}

				prefixSum();

				// Actually re-order the keys
				for (UINT32 threadId = 0; threadId < NUM_THREADS; threadId++)
				{
					UINT32 localOffsets[NUM_DIGITS];
					for (UINT32 i = 0; i < NUM_DIGITS; i++)
						localOffsets[i] = 0;

					for (UINT32 i = 0; i < KEYS_PER_LOOP; i++)
					{
						UINT32 localIdx = threadId * KEYS_PER_LOOP + i;

						if (localIdx >= gpuSortProps.extraKeys)
							continue;

						UINT32 idx = keyBegin[threadId] + localIdx;
						UINT32 key = inputKeys[idx];
						UINT32 digit = (key >> bitOffset) & KEY_MASK;

						UINT32 offset = sGroupOffsets[digit] + sTileTotals[digit] +
							sLocalScratch[threadId * NUM_DIGITS + digit] + localOffsets[digit];
						localOffsets[digit]++;

						// Note: First write to local memory then attempt to coalesce when writing to global?
						sortedKeys[offset] = key;
					}
				}
			}
		}

		// PARALLEL:
		RadixSortReorderMat::get()->execute(gpuSortProps.numGroups, params, helperBuffers[1], sortBuffers, 0);
		RenderAPI::instance().submitCommandBuffer(nullptr);

		// Compare with GPU keys
		Vector<UINT32> bufferSortedKeys(count);
		sortBuffers.keys[1]->readData(0, count * sizeof(UINT32), bufferSortedKeys.data());

		for(UINT32 i = 0; i < count; i++)
			assert(bufferSortedKeys[i] == sortedKeys[i]);

		// Ensure everything is actually sorted
		assert(std::is_sorted(sortedKeys.begin(), sortedKeys.end()));
	}
}}