| 1 | #pragma once |
| 2 | |
| 3 | |
| 4 | #include <string.h> |
| 5 | #if !defined(__APPLE__) && !defined(__FreeBSD__) |
| 6 | #include <malloc.h> |
| 7 | #endif |
| 8 | #include <algorithm> |
| 9 | #include <cmath> |
| 10 | #include <cstdlib> |
| 11 | #include <cstdint> |
| 12 | #include <type_traits> |
| 13 | |
| 14 | #include <ext/bit_cast.h> |
| 15 | #include <Core/Types.h> |
| 16 | #include <Core/Defines.h> |
| 17 | |
| 18 | |
/** Radix sort, with the following functionality:
  * Can sort unsigned and signed integers, and floats.
  * Can sort an array of fixed-length elements that contain something else besides the key.
  * Customizable radix size.
  *
  * LSD (least significant digit first) variant, stable: RadixSort::executeLSD.
  * MSD (most significant digit first) variant, not stable, supports partial sorting: RadixSort::executeMSD.
  * NOTE For some applications it makes sense to add radix-select, radix-partial-sort
  *  and radix-get-permutation algorithms based on the MSD variant.
  */
| 28 | |
| 29 | |
/** Default allocator for the temporary buffer of RadixSort, backed by plain malloc/free.
  * Used as a template parameter (Traits::Allocator). See below.
  */
struct RadixSortMallocAllocator
{
    /// Returns a block of at least `size` bytes obtained from malloc (nullptr if malloc fails).
    void * allocate(size_t size)
    {
        void * memory = malloc(size);
        return memory;
    }

    /// Releases a block returned by allocate(). free() does not need the size, so it is ignored.
    void deallocate(void * ptr, size_t /*size*/)
    {
        free(ptr);
    }
};
| 44 | |
| 45 | |
/** Transforms the bit representation of a floating-point key into an unsigned integer,
  * so that the order relation over the keys matches the order relation over the obtained unsigned numbers.
  * forward() does the following:
  *  if the sign bit is set (negative numbers), it flips all bits;
  *  otherwise it flips only the sign bit.
  * As a result, -0.0 goes right before +0.0, NaNs with the sign bit cleared are bigger than +inf,
  *  and NaNs with the sign bit set are smaller than -inf.
  * backward() is the inverse of forward().
  */
template <typename KeyBits>
struct RadixSortFloatTransform
{
    /// Is it worth writing the result in memory, or is it better to do calculation every time again?
    /// Not simple: the result is stored back into the array instead of being recomputed for every digit.
    static constexpr bool transform_is_simple = false;

    static KeyBits forward(KeyBits x)
    {
        constexpr auto top_bit = sizeof(KeyBits) * 8 - 1;
        const KeyBits sign_mask = KeyBits(1) << top_bit;
        /// All ones if the sign bit of x is set, zero otherwise.
        const KeyBits negative_mask = -(x >> top_bit);
        return x ^ (negative_mask | sign_mask);
    }

    static KeyBits backward(KeyBits x)
    {
        constexpr auto top_bit = sizeof(KeyBits) * 8 - 1;
        const KeyBits sign_mask = KeyBits(1) << top_bit;
        /// The top bit of a transformed value is clear exactly for originally negative numbers:
        /// then this is all ones (flip everything back), otherwise zero (flip only the sign bit).
        const KeyBits negative_mask = (x >> top_bit) - 1;
        return x ^ (negative_mask | sign_mask);
    }
};
| 68 | |
| 69 | |
| 70 | template <typename TElement> |
| 71 | struct RadixSortFloatTraits |
| 72 | { |
| 73 | using Element = TElement; /// The type of the element. It can be a structure with a key and some other payload. Or just a key. |
| 74 | using Key = Element; /// The key to sort by. |
| 75 | using CountType = uint32_t; /// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t. |
| 76 | |
| 77 | /// The type to which the key is transformed to do bit operations. This UInt is the same size as the key. |
| 78 | using KeyBits = std::conditional_t<sizeof(Key) == 8, uint64_t, uint32_t>; |
| 79 | |
| 80 | static constexpr size_t PART_SIZE_BITS = 8; /// With what pieces of the key, in bits, to do one pass - reshuffle of the array. |
| 81 | |
| 82 | /// Converting a key into KeyBits is such that the order relation over the key corresponds to the order relation over KeyBits. |
| 83 | using Transform = RadixSortFloatTransform<KeyBits>; |
| 84 | |
| 85 | /// An object with the functions allocate and deallocate. |
| 86 | /// Can be used, for example, to allocate memory for a temporary array on the stack. |
| 87 | /// To do this, the allocator itself is created on the stack. |
| 88 | using Allocator = RadixSortMallocAllocator; |
| 89 | |
| 90 | /// The function to get the key from an array element. |
| 91 | static Key & (Element & elem) { return elem; } |
| 92 | |
| 93 | /// Used when fallback to comparison based sorting is needed. |
| 94 | /// TODO: Correct handling of NaNs, NULLs, etc |
| 95 | static bool less(Key x, Key y) |
| 96 | { |
| 97 | return x < y; |
| 98 | } |
| 99 | }; |
| 100 | |
| 101 | |
/// Identity transform: used for unsigned keys, whose bit representation is already ordered like the keys.
template <typename KeyBits>
struct RadixSortIdentityTransform
{
    /// Free to recompute, so it is applied on the fly instead of being stored in the array.
    static constexpr bool transform_is_simple = true;

    static KeyBits forward(KeyBits bits)
    {
        return bits;
    }

    static KeyBits backward(KeyBits bits)
    {
        return bits;
    }
};
| 110 | |
| 111 | |
| 112 | |
| 113 | template <typename TElement> |
| 114 | struct RadixSortUIntTraits |
| 115 | { |
| 116 | using Element = TElement; |
| 117 | using Key = Element; |
| 118 | using CountType = uint32_t; |
| 119 | using KeyBits = Key; |
| 120 | |
| 121 | static constexpr size_t PART_SIZE_BITS = 8; |
| 122 | |
| 123 | using Transform = RadixSortIdentityTransform<KeyBits>; |
| 124 | using Allocator = RadixSortMallocAllocator; |
| 125 | |
| 126 | static Key & (Element & elem) { return elem; } |
| 127 | |
| 128 | static bool less(Key x, Key y) |
| 129 | { |
| 130 | return x < y; |
| 131 | } |
| 132 | }; |
| 133 | |
| 134 | |
/// Order-preserving transform for two's complement signed integers viewed as unsigned bits:
/// flipping the sign bit makes negative numbers compare below non-negative ones as unsigned.
template <typename KeyBits>
struct RadixSortSignedTransform
{
    /// Cheap enough to be applied on the fly for every digit extraction.
    static constexpr bool transform_is_simple = true;

    static KeyBits forward(KeyBits x)
    {
        const KeyBits sign_bit = KeyBits(1) << (sizeof(KeyBits) * 8 - 1);
        return x ^ sign_bit;
    }

    /// Flipping the sign bit is its own inverse.
    static KeyBits backward(KeyBits x)
    {
        return forward(x);
    }
};
| 143 | |
| 144 | |
| 145 | template <typename TElement> |
| 146 | struct RadixSortIntTraits |
| 147 | { |
| 148 | using Element = TElement; |
| 149 | using Key = Element; |
| 150 | using CountType = uint32_t; |
| 151 | using KeyBits = std::make_unsigned_t<Key>; |
| 152 | |
| 153 | static constexpr size_t PART_SIZE_BITS = 8; |
| 154 | |
| 155 | using Transform = RadixSortSignedTransform<KeyBits>; |
| 156 | using Allocator = RadixSortMallocAllocator; |
| 157 | |
| 158 | static Key & (Element & elem) { return elem; } |
| 159 | |
| 160 | static bool less(Key x, Key y) |
| 161 | { |
| 162 | return x < y; |
| 163 | } |
| 164 | }; |
| 165 | |
| 166 | |
| 167 | template <typename T> |
| 168 | using RadixSortNumTraits = std::conditional_t< |
| 169 | is_integral_v<T>, |
| 170 | std::conditional_t<is_unsigned_v<T>, RadixSortUIntTraits<T>, RadixSortIntTraits<T>>, |
| 171 | RadixSortFloatTraits<T>>; |
| 172 | |
| 173 | |
| 174 | template <typename Traits> |
| 175 | struct RadixSort |
| 176 | { |
| 177 | private: |
| 178 | using Element = typename Traits::Element; |
| 179 | using Key = typename Traits::Key; |
| 180 | using CountType = typename Traits::CountType; |
| 181 | using KeyBits = typename Traits::KeyBits; |
| 182 | |
| 183 | // Use insertion sort if the size of the array is less than equal to this threshold |
| 184 | static constexpr size_t INSERTION_SORT_THRESHOLD = 64; |
| 185 | |
| 186 | static constexpr size_t HISTOGRAM_SIZE = 1 << Traits::PART_SIZE_BITS; |
| 187 | static constexpr size_t PART_BITMASK = HISTOGRAM_SIZE - 1; |
| 188 | static constexpr size_t KEY_BITS = sizeof(Key) * 8; |
| 189 | static constexpr size_t NUM_PASSES = (KEY_BITS + (Traits::PART_SIZE_BITS - 1)) / Traits::PART_SIZE_BITS; |
| 190 | |
| 191 | static ALWAYS_INLINE KeyBits getPart(size_t N, KeyBits x) |
| 192 | { |
| 193 | if (Traits::Transform::transform_is_simple) |
| 194 | x = Traits::Transform::forward(x); |
| 195 | |
| 196 | return (x >> (N * Traits::PART_SIZE_BITS)) & PART_BITMASK; |
| 197 | } |
| 198 | |
| 199 | static KeyBits keyToBits(Key x) { return ext::bit_cast<KeyBits>(x); } |
| 200 | static Key bitsToKey(KeyBits x) { return ext::bit_cast<Key>(x); } |
| 201 | |
| 202 | static void insertionSortInternal(Element *arr, size_t size) |
| 203 | { |
| 204 | Element * end = arr + size; |
| 205 | for (Element * i = arr + 1; i < end; ++i) |
| 206 | { |
| 207 | if (Traits::less(Traits::extractKey(*i), Traits::extractKey(*(i - 1)))) |
| 208 | { |
| 209 | Element * j; |
| 210 | Element tmp = *i; |
| 211 | *i = *(i - 1); |
| 212 | for (j = i - 1; j > arr && Traits::less(Traits::extractKey(tmp), Traits::extractKey(*(j - 1))); --j) |
| 213 | *j = *(j - 1); |
| 214 | *j = tmp; |
| 215 | } |
| 216 | } |
| 217 | } |
| 218 | |
| 219 | /* Main MSD radix sort subroutine |
| 220 | * Puts elements to buckets based on PASS-th digit, then recursively calls insertion sort or itself on the buckets |
| 221 | */ |
| 222 | template <size_t PASS> |
| 223 | static inline void radixSortMSDInternal(Element * arr, size_t size, size_t limit) |
| 224 | { |
| 225 | Element * last_list[HISTOGRAM_SIZE + 1]; |
| 226 | Element ** last = last_list + 1; |
| 227 | size_t count[HISTOGRAM_SIZE] = {0}; |
| 228 | |
| 229 | for (Element * i = arr; i < arr + size; ++i) |
| 230 | ++count[getPart(PASS, *i)]; |
| 231 | |
| 232 | last_list[0] = last_list[1] = arr; |
| 233 | |
| 234 | size_t buckets_for_recursion = HISTOGRAM_SIZE; |
| 235 | Element * finish = arr + size; |
| 236 | for (size_t i = 1; i < HISTOGRAM_SIZE; ++i) |
| 237 | { |
| 238 | last[i] = last[i - 1] + count[i - 1]; |
| 239 | if (last[i] >= arr + limit) |
| 240 | { |
| 241 | buckets_for_recursion = i; |
| 242 | finish = last[i]; |
| 243 | } |
| 244 | } |
| 245 | |
| 246 | /* At this point, we have the following variables: |
| 247 | * count[i] is the size of i-th bucket |
| 248 | * last[i] is a pointer to the beginning of i-th bucket, last[-1] == last[0] |
| 249 | * buckets_for_recursion is the number of buckets that should be sorted, the last of them only partially |
| 250 | * finish is a pointer to the end of the first buckets_for_recursion buckets |
| 251 | */ |
| 252 | |
| 253 | // Scatter array elements to buckets until the first buckets_for_recursion buckets are full |
| 254 | for (size_t i = 0; i < buckets_for_recursion; ++i) |
| 255 | { |
| 256 | Element * end = last[i - 1] + count[i]; |
| 257 | if (end == finish) |
| 258 | { |
| 259 | last[i] = end; |
| 260 | break; |
| 261 | } |
| 262 | while (last[i] != end) |
| 263 | { |
| 264 | Element swapper = *last[i]; |
| 265 | KeyBits tag = getPart(PASS, swapper); |
| 266 | if (tag != i) |
| 267 | { |
| 268 | do |
| 269 | { |
| 270 | std::swap(swapper, *last[tag]++); |
| 271 | } while ((tag = getPart(PASS, swapper)) != i); |
| 272 | *last[i] = swapper; |
| 273 | } |
| 274 | ++last[i]; |
| 275 | } |
| 276 | } |
| 277 | |
| 278 | if constexpr (PASS > 0) |
| 279 | { |
| 280 | // Recursively sort buckets, except the last one |
| 281 | for (size_t i = 0; i < buckets_for_recursion - 1; ++i) |
| 282 | { |
| 283 | Element * start = last[i - 1]; |
| 284 | size_t subsize = last[i] - last[i - 1]; |
| 285 | radixSortMSDInternalHelper<PASS - 1>(start, subsize, subsize); |
| 286 | } |
| 287 | |
| 288 | // Sort last necessary bucket with limit |
| 289 | Element * start = last[buckets_for_recursion - 2]; |
| 290 | size_t subsize = last[buckets_for_recursion - 1] - last[buckets_for_recursion - 2]; |
| 291 | size_t sublimit = limit - (last[buckets_for_recursion - 1] - arr); |
| 292 | radixSortMSDInternalHelper<PASS - 1>(start, subsize, sublimit); |
| 293 | } |
| 294 | } |
| 295 | |
| 296 | // A helper to choose sorting algorithm based on array length |
| 297 | template <size_t PASS> |
| 298 | static inline void radixSortMSDInternalHelper(Element * arr, size_t size, size_t limit) |
| 299 | { |
| 300 | if (size <= INSERTION_SORT_THRESHOLD) |
| 301 | insertionSortInternal(arr, size); |
| 302 | else |
| 303 | radixSortMSDInternal<PASS>(arr, size, limit); |
| 304 | } |
| 305 | |
| 306 | public: |
| 307 | /// Least significant digit radix sort (stable) |
| 308 | static void executeLSD(Element * arr, size_t size) |
| 309 | { |
| 310 | /// If the array is smaller than 256, then it is better to use another algorithm. |
| 311 | |
| 312 | /// There are loops of NUM_PASSES. It is very important that they are unfolded at compile-time. |
| 313 | |
| 314 | /// For each of the NUM_PASSES bit ranges of the key, consider how many times each value of this bit range met. |
| 315 | CountType histograms[HISTOGRAM_SIZE * NUM_PASSES] = {0}; |
| 316 | |
| 317 | typename Traits::Allocator allocator; |
| 318 | |
| 319 | /// We will do several passes through the array. On each pass, the data is transferred to another array. Let's allocate this temporary array. |
| 320 | Element * swap_buffer = reinterpret_cast<Element *>(allocator.allocate(size * sizeof(Element))); |
| 321 | |
| 322 | /// Transform the array and calculate the histogram. |
| 323 | /// NOTE This is slightly suboptimal. Look at https://github.com/powturbo/TurboHist |
| 324 | for (size_t i = 0; i < size; ++i) |
| 325 | { |
| 326 | if (!Traits::Transform::transform_is_simple) |
| 327 | Traits::extractKey(arr[i]) = bitsToKey(Traits::Transform::forward(keyToBits(Traits::extractKey(arr[i])))); |
| 328 | |
| 329 | for (size_t pass = 0; pass < NUM_PASSES; ++pass) |
| 330 | ++histograms[pass * HISTOGRAM_SIZE + getPart(pass, keyToBits(Traits::extractKey(arr[i])))]; |
| 331 | } |
| 332 | |
| 333 | { |
| 334 | /// Replace the histograms with the accumulated sums: the value in position i is the sum of the previous positions minus one. |
| 335 | size_t sums[NUM_PASSES] = {0}; |
| 336 | |
| 337 | for (size_t i = 0; i < HISTOGRAM_SIZE; ++i) |
| 338 | { |
| 339 | for (size_t pass = 0; pass < NUM_PASSES; ++pass) |
| 340 | { |
| 341 | size_t tmp = histograms[pass * HISTOGRAM_SIZE + i] + sums[pass]; |
| 342 | histograms[pass * HISTOGRAM_SIZE + i] = sums[pass] - 1; |
| 343 | sums[pass] = tmp; |
| 344 | } |
| 345 | } |
| 346 | } |
| 347 | |
| 348 | /// Move the elements in the order starting from the least bit piece, and then do a few passes on the number of pieces. |
| 349 | for (size_t pass = 0; pass < NUM_PASSES; ++pass) |
| 350 | { |
| 351 | Element * writer = pass % 2 ? arr : swap_buffer; |
| 352 | Element * reader = pass % 2 ? swap_buffer : arr; |
| 353 | |
| 354 | for (size_t i = 0; i < size; ++i) |
| 355 | { |
| 356 | size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i]))); |
| 357 | |
| 358 | /// Place the element on the next free position. |
| 359 | auto & dest = writer[++histograms[pass * HISTOGRAM_SIZE + pos]]; |
| 360 | dest = reader[i]; |
| 361 | |
| 362 | /// On the last pass, we do the reverse transformation. |
| 363 | if (!Traits::Transform::transform_is_simple && pass == NUM_PASSES - 1) |
| 364 | Traits::extractKey(dest) = bitsToKey(Traits::Transform::backward(keyToBits(Traits::extractKey(reader[i])))); |
| 365 | } |
| 366 | } |
| 367 | |
| 368 | /// If the number of passes is odd, the result array is in a temporary buffer. Copy it to the place of the original array. |
| 369 | /// NOTE Sometimes it will be more optimal to provide non-destructive interface, that will not modify original array. |
| 370 | if (NUM_PASSES % 2) |
| 371 | memcpy(arr, swap_buffer, size * sizeof(Element)); |
| 372 | |
| 373 | allocator.deallocate(swap_buffer, size * sizeof(Element)); |
| 374 | } |
| 375 | |
| 376 | /* Most significant digit radix sort |
| 377 | * Usually slower than LSD and is not stable, but allows partial sorting |
| 378 | * |
| 379 | * Based on https://github.com/voutcn/kxsort, license: |
| 380 | * The MIT License |
| 381 | * Copyright (c) 2016 Dinghua Li <voutcn@gmail.com> |
| 382 | * |
| 383 | * Permission is hereby granted, free of charge, to any person obtaining |
| 384 | * a copy of this software and associated documentation files (the |
| 385 | * "Software"), to deal in the Software without restriction, including |
| 386 | * without limitation the rights to use, copy, modify, merge, publish, |
| 387 | * distribute, sublicense, and/or sell copies of the Software, and to |
| 388 | * permit persons to whom the Software is furnished to do so, subject to |
| 389 | * the following conditions: |
| 390 | * |
| 391 | * The above copyright notice and this permission notice shall be |
| 392 | * included in all copies or substantial portions of the Software. |
| 393 | * |
| 394 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
| 395 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
| 396 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND |
| 397 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS |
| 398 | * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN |
| 399 | * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
| 400 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 401 | * SOFTWARE. |
| 402 | */ |
| 403 | static void executeMSD(Element * arr, size_t size, size_t limit) |
| 404 | { |
| 405 | limit = std::min(limit, size); |
| 406 | radixSortMSDInternalHelper<NUM_PASSES - 1>(arr, size, limit); |
| 407 | } |
| 408 | }; |
| 409 | |
| 410 | |
| 411 | /// Helper functions for numeric types. |
| 412 | /// Use RadixSort with custom traits for complex types instead. |
| 413 | |
| 414 | template <typename T> |
| 415 | void radixSortLSD(T *arr, size_t size) |
| 416 | { |
| 417 | RadixSort<RadixSortNumTraits<T>>::executeLSD(arr, size); |
| 418 | } |
| 419 | |
| 420 | template <typename T> |
| 421 | void radixSortMSD(T *arr, size_t size, size_t limit) |
| 422 | { |
| 423 | RadixSort<RadixSortNumTraits<T>>::executeMSD(arr, size, limit); |
| 424 | } |
| 425 | |