1#pragma once
2
3
4#include <string.h>
5#if !defined(__APPLE__) && !defined(__FreeBSD__)
6#include <malloc.h>
7#endif
8#include <algorithm>
9#include <cmath>
10#include <cstdlib>
11#include <cstdint>
12#include <type_traits>
13
14#include <ext/bit_cast.h>
15#include <Core/Types.h>
16#include <Core/Defines.h>
17
18
19/** Radix sort, has the following functionality:
20 * Can sort unsigned, signed numbers, and floats.
21 * Can sort an array of fixed length elements that contain something else besides the key.
22 * Customizable radix size.
23 *
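 * Two variants are implemented: LSD (least significant digit first), which is stable,
 * and MSD (most significant digit first), which is not stable but allows partial sorting.
 * NOTE For some applications it makes sense to add radix-select and radix-get-permutation algorithms based on it.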
27 */
28
29
/** Default allocator for the temporary buffer of the radix sort.
 * Used as the `Allocator` member of the traits structures below.
 * Memory comes straight from the C heap; the result of allocate() is not checked here.
 */
struct RadixSortMallocAllocator
{
    /// Returns `size` bytes obtained from malloc (whatever malloc returns, including a null pointer).
    void * allocate(size_t size)
    {
        void * memory = std::malloc(size);
        return memory;
    }

    /// Releases memory previously obtained from allocate(). The size is not needed by free().
    void deallocate(void * ptr, size_t /*size*/)
    {
        std::free(ptr);
    }
};
44
45
/** Maps the bit representation of a floating point key to an unsigned integer
 * such that comparing the resulting integers orders the keys the same way as comparing the floats.
 *
 * forward():  if the sign bit is set, all bits are inverted; otherwise only the sign bit is inverted.
 * backward(): the exact inverse of forward().
 *
 * With this mapping, NaN-s whose sign bit is clear are bigger than all normal numbers.
 */
template <typename KeyBits>
struct RadixSortFloatTransform
{
    /// Is it worth writing the result in memory, or is it better to do calculation every time again?
    /// false: the sort stores the transformed keys in the array instead of recomputing them.
    static constexpr bool transform_is_simple = false;

    static KeyBits forward(KeyBits x)
    {
        constexpr size_t sign_shift = sizeof(KeyBits) * 8 - 1;
        const KeyBits sign_bit = KeyBits(1) << sign_shift;

        /// All ones for a set sign bit, zero otherwise.
        const KeyBits negative_mask = -(x >> sign_shift);

        return x ^ (negative_mask | sign_bit);
    }

    static KeyBits backward(KeyBits x)
    {
        constexpr size_t sign_shift = sizeof(KeyBits) * 8 - 1;
        const KeyBits sign_bit = KeyBits(1) << sign_shift;

        /// After forward() a set top bit marks an originally non-negative value:
        /// zero in that case (only the top bit is restored), all ones otherwise (everything is inverted back).
        const KeyBits negative_mask = (x >> sign_shift) - 1;

        return x ^ (negative_mask | sign_bit);
    }
};
68
69
70template <typename TElement>
71struct RadixSortFloatTraits
72{
73 using Element = TElement; /// The type of the element. It can be a structure with a key and some other payload. Or just a key.
74 using Key = Element; /// The key to sort by.
75 using CountType = uint32_t; /// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t.
76
77 /// The type to which the key is transformed to do bit operations. This UInt is the same size as the key.
78 using KeyBits = std::conditional_t<sizeof(Key) == 8, uint64_t, uint32_t>;
79
80 static constexpr size_t PART_SIZE_BITS = 8; /// With what pieces of the key, in bits, to do one pass - reshuffle of the array.
81
82 /// Converting a key into KeyBits is such that the order relation over the key corresponds to the order relation over KeyBits.
83 using Transform = RadixSortFloatTransform<KeyBits>;
84
85 /// An object with the functions allocate and deallocate.
86 /// Can be used, for example, to allocate memory for a temporary array on the stack.
87 /// To do this, the allocator itself is created on the stack.
88 using Allocator = RadixSortMallocAllocator;
89
90 /// The function to get the key from an array element.
91 static Key & extractKey(Element & elem) { return elem; }
92
93 /// Used when fallback to comparison based sorting is needed.
94 /// TODO: Correct handling of NaNs, NULLs, etc
95 static bool less(Key x, Key y)
96 {
97 return x < y;
98 }
99};
100
101
/** Transform for keys whose bits already compare correctly as an unsigned integer:
 * both directions return the argument unchanged.
 */
template <typename KeyBits>
struct RadixSortIdentityTransform
{
    /// true: the transform is cheap enough to be applied each time a part of the key is read.
    static constexpr bool transform_is_simple = true;

    static KeyBits forward(KeyBits x)
    {
        return x;
    }

    static KeyBits backward(KeyBits x)
    {
        return x;
    }
};
110
111
112
113template <typename TElement>
114struct RadixSortUIntTraits
115{
116 using Element = TElement;
117 using Key = Element;
118 using CountType = uint32_t;
119 using KeyBits = Key;
120
121 static constexpr size_t PART_SIZE_BITS = 8;
122
123 using Transform = RadixSortIdentityTransform<KeyBits>;
124 using Allocator = RadixSortMallocAllocator;
125
126 static Key & extractKey(Element & elem) { return elem; }
127
128 static bool less(Key x, Key y)
129 {
130 return x < y;
131 }
132};
133
134
/** Transform for signed integers viewed as unsigned bits: inverting the top (sign) bit makes
 * negative values compare below non-negative ones. The operation is its own inverse.
 */
template <typename KeyBits>
struct RadixSortSignedTransform
{
    /// true: the transform is cheap enough to be applied each time a part of the key is read.
    static constexpr bool transform_is_simple = true;

    static KeyBits forward(KeyBits x)
    {
        constexpr size_t sign_shift = sizeof(KeyBits) * 8 - 1;
        return x ^ (KeyBits(1) << sign_shift);
    }

    static KeyBits backward(KeyBits x)
    {
        constexpr size_t sign_shift = sizeof(KeyBits) * 8 - 1;
        return x ^ (KeyBits(1) << sign_shift);
    }
};
143
144
145template <typename TElement>
146struct RadixSortIntTraits
147{
148 using Element = TElement;
149 using Key = Element;
150 using CountType = uint32_t;
151 using KeyBits = std::make_unsigned_t<Key>;
152
153 static constexpr size_t PART_SIZE_BITS = 8;
154
155 using Transform = RadixSortSignedTransform<KeyBits>;
156 using Allocator = RadixSortMallocAllocator;
157
158 static Key & extractKey(Element & elem) { return elem; }
159
160 static bool less(Key x, Key y)
161 {
162 return x < y;
163 }
164};
165
166
167template <typename T>
168using RadixSortNumTraits = std::conditional_t<
169 is_integral_v<T>,
170 std::conditional_t<is_unsigned_v<T>, RadixSortUIntTraits<T>, RadixSortIntTraits<T>>,
171 RadixSortFloatTraits<T>>;
172
173
174template <typename Traits>
175struct RadixSort
176{
177private:
178 using Element = typename Traits::Element;
179 using Key = typename Traits::Key;
180 using CountType = typename Traits::CountType;
181 using KeyBits = typename Traits::KeyBits;
182
183 // Use insertion sort if the size of the array is less than equal to this threshold
184 static constexpr size_t INSERTION_SORT_THRESHOLD = 64;
185
186 static constexpr size_t HISTOGRAM_SIZE = 1 << Traits::PART_SIZE_BITS;
187 static constexpr size_t PART_BITMASK = HISTOGRAM_SIZE - 1;
188 static constexpr size_t KEY_BITS = sizeof(Key) * 8;
189 static constexpr size_t NUM_PASSES = (KEY_BITS + (Traits::PART_SIZE_BITS - 1)) / Traits::PART_SIZE_BITS;
190
191 static ALWAYS_INLINE KeyBits getPart(size_t N, KeyBits x)
192 {
193 if (Traits::Transform::transform_is_simple)
194 x = Traits::Transform::forward(x);
195
196 return (x >> (N * Traits::PART_SIZE_BITS)) & PART_BITMASK;
197 }
198
199 static KeyBits keyToBits(Key x) { return ext::bit_cast<KeyBits>(x); }
200 static Key bitsToKey(KeyBits x) { return ext::bit_cast<Key>(x); }
201
202 static void insertionSortInternal(Element *arr, size_t size)
203 {
204 Element * end = arr + size;
205 for (Element * i = arr + 1; i < end; ++i)
206 {
207 if (Traits::less(Traits::extractKey(*i), Traits::extractKey(*(i - 1))))
208 {
209 Element * j;
210 Element tmp = *i;
211 *i = *(i - 1);
212 for (j = i - 1; j > arr && Traits::less(Traits::extractKey(tmp), Traits::extractKey(*(j - 1))); --j)
213 *j = *(j - 1);
214 *j = tmp;
215 }
216 }
217 }
218
219 /* Main MSD radix sort subroutine
220 * Puts elements to buckets based on PASS-th digit, then recursively calls insertion sort or itself on the buckets
221 */
222 template <size_t PASS>
223 static inline void radixSortMSDInternal(Element * arr, size_t size, size_t limit)
224 {
225 Element * last_list[HISTOGRAM_SIZE + 1];
226 Element ** last = last_list + 1;
227 size_t count[HISTOGRAM_SIZE] = {0};
228
229 for (Element * i = arr; i < arr + size; ++i)
230 ++count[getPart(PASS, *i)];
231
232 last_list[0] = last_list[1] = arr;
233
234 size_t buckets_for_recursion = HISTOGRAM_SIZE;
235 Element * finish = arr + size;
236 for (size_t i = 1; i < HISTOGRAM_SIZE; ++i)
237 {
238 last[i] = last[i - 1] + count[i - 1];
239 if (last[i] >= arr + limit)
240 {
241 buckets_for_recursion = i;
242 finish = last[i];
243 }
244 }
245
246 /* At this point, we have the following variables:
247 * count[i] is the size of i-th bucket
248 * last[i] is a pointer to the beginning of i-th bucket, last[-1] == last[0]
249 * buckets_for_recursion is the number of buckets that should be sorted, the last of them only partially
250 * finish is a pointer to the end of the first buckets_for_recursion buckets
251 */
252
253 // Scatter array elements to buckets until the first buckets_for_recursion buckets are full
254 for (size_t i = 0; i < buckets_for_recursion; ++i)
255 {
256 Element * end = last[i - 1] + count[i];
257 if (end == finish)
258 {
259 last[i] = end;
260 break;
261 }
262 while (last[i] != end)
263 {
264 Element swapper = *last[i];
265 KeyBits tag = getPart(PASS, swapper);
266 if (tag != i)
267 {
268 do
269 {
270 std::swap(swapper, *last[tag]++);
271 } while ((tag = getPart(PASS, swapper)) != i);
272 *last[i] = swapper;
273 }
274 ++last[i];
275 }
276 }
277
278 if constexpr (PASS > 0)
279 {
280 // Recursively sort buckets, except the last one
281 for (size_t i = 0; i < buckets_for_recursion - 1; ++i)
282 {
283 Element * start = last[i - 1];
284 size_t subsize = last[i] - last[i - 1];
285 radixSortMSDInternalHelper<PASS - 1>(start, subsize, subsize);
286 }
287
288 // Sort last necessary bucket with limit
289 Element * start = last[buckets_for_recursion - 2];
290 size_t subsize = last[buckets_for_recursion - 1] - last[buckets_for_recursion - 2];
291 size_t sublimit = limit - (last[buckets_for_recursion - 1] - arr);
292 radixSortMSDInternalHelper<PASS - 1>(start, subsize, sublimit);
293 }
294 }
295
296 // A helper to choose sorting algorithm based on array length
297 template <size_t PASS>
298 static inline void radixSortMSDInternalHelper(Element * arr, size_t size, size_t limit)
299 {
300 if (size <= INSERTION_SORT_THRESHOLD)
301 insertionSortInternal(arr, size);
302 else
303 radixSortMSDInternal<PASS>(arr, size, limit);
304 }
305
306public:
307 /// Least significant digit radix sort (stable)
308 static void executeLSD(Element * arr, size_t size)
309 {
310 /// If the array is smaller than 256, then it is better to use another algorithm.
311
312 /// There are loops of NUM_PASSES. It is very important that they are unfolded at compile-time.
313
314 /// For each of the NUM_PASSES bit ranges of the key, consider how many times each value of this bit range met.
315 CountType histograms[HISTOGRAM_SIZE * NUM_PASSES] = {0};
316
317 typename Traits::Allocator allocator;
318
319 /// We will do several passes through the array. On each pass, the data is transferred to another array. Let's allocate this temporary array.
320 Element * swap_buffer = reinterpret_cast<Element *>(allocator.allocate(size * sizeof(Element)));
321
322 /// Transform the array and calculate the histogram.
323 /// NOTE This is slightly suboptimal. Look at https://github.com/powturbo/TurboHist
324 for (size_t i = 0; i < size; ++i)
325 {
326 if (!Traits::Transform::transform_is_simple)
327 Traits::extractKey(arr[i]) = bitsToKey(Traits::Transform::forward(keyToBits(Traits::extractKey(arr[i]))));
328
329 for (size_t pass = 0; pass < NUM_PASSES; ++pass)
330 ++histograms[pass * HISTOGRAM_SIZE + getPart(pass, keyToBits(Traits::extractKey(arr[i])))];
331 }
332
333 {
334 /// Replace the histograms with the accumulated sums: the value in position i is the sum of the previous positions minus one.
335 size_t sums[NUM_PASSES] = {0};
336
337 for (size_t i = 0; i < HISTOGRAM_SIZE; ++i)
338 {
339 for (size_t pass = 0; pass < NUM_PASSES; ++pass)
340 {
341 size_t tmp = histograms[pass * HISTOGRAM_SIZE + i] + sums[pass];
342 histograms[pass * HISTOGRAM_SIZE + i] = sums[pass] - 1;
343 sums[pass] = tmp;
344 }
345 }
346 }
347
348 /// Move the elements in the order starting from the least bit piece, and then do a few passes on the number of pieces.
349 for (size_t pass = 0; pass < NUM_PASSES; ++pass)
350 {
351 Element * writer = pass % 2 ? arr : swap_buffer;
352 Element * reader = pass % 2 ? swap_buffer : arr;
353
354 for (size_t i = 0; i < size; ++i)
355 {
356 size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i])));
357
358 /// Place the element on the next free position.
359 auto & dest = writer[++histograms[pass * HISTOGRAM_SIZE + pos]];
360 dest = reader[i];
361
362 /// On the last pass, we do the reverse transformation.
363 if (!Traits::Transform::transform_is_simple && pass == NUM_PASSES - 1)
364 Traits::extractKey(dest) = bitsToKey(Traits::Transform::backward(keyToBits(Traits::extractKey(reader[i]))));
365 }
366 }
367
368 /// If the number of passes is odd, the result array is in a temporary buffer. Copy it to the place of the original array.
369 /// NOTE Sometimes it will be more optimal to provide non-destructive interface, that will not modify original array.
370 if (NUM_PASSES % 2)
371 memcpy(arr, swap_buffer, size * sizeof(Element));
372
373 allocator.deallocate(swap_buffer, size * sizeof(Element));
374 }
375
376 /* Most significant digit radix sort
377 * Usually slower than LSD and is not stable, but allows partial sorting
378 *
379 * Based on https://github.com/voutcn/kxsort, license:
380 * The MIT License
381 * Copyright (c) 2016 Dinghua Li <voutcn@gmail.com>
382 *
383 * Permission is hereby granted, free of charge, to any person obtaining
384 * a copy of this software and associated documentation files (the
385 * "Software"), to deal in the Software without restriction, including
386 * without limitation the rights to use, copy, modify, merge, publish,
387 * distribute, sublicense, and/or sell copies of the Software, and to
388 * permit persons to whom the Software is furnished to do so, subject to
389 * the following conditions:
390 *
391 * The above copyright notice and this permission notice shall be
392 * included in all copies or substantial portions of the Software.
393 *
394 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
395 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
396 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
397 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
398 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
399 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
400 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
401 * SOFTWARE.
402 */
403 static void executeMSD(Element * arr, size_t size, size_t limit)
404 {
405 limit = std::min(limit, size);
406 radixSortMSDInternalHelper<NUM_PASSES - 1>(arr, size, limit);
407 }
408};
409
410
411/// Helper functions for numeric types.
412/// Use RadixSort with custom traits for complex types instead.
413
414template <typename T>
415void radixSortLSD(T *arr, size_t size)
416{
417 RadixSort<RadixSortNumTraits<T>>::executeLSD(arr, size);
418}
419
420template <typename T>
421void radixSortMSD(T *arr, size_t size, size_t limit)
422{
423 RadixSort<RadixSortNumTraits<T>>::executeMSD(arr, size, limit);
424}
425