1#include "duckdb/common/sort/comparators.hpp"
2
3#include "duckdb/common/fast_mem.hpp"
4#include "duckdb/common/sort/sort.hpp"
5
6namespace duckdb {
7
8bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) {
9 const auto &col_idx = sort_layout.sorting_to_blob_col.at(k: tie_col);
10 // Check if the blob is NULL
11 ValidityBytes row_mask(row_ptr);
12 idx_t entry_idx;
13 idx_t idx_in_entry;
14 ValidityBytes::GetEntryIndex(row_idx: col_idx, entry_idx, idx_in_entry);
15 if (!row_mask.RowIsValid(entry: row_mask.GetValidityEntry(entry_idx), idx_in_entry)) {
16 // Can't break a NULL tie
17 return false;
18 }
19 auto &row_layout = sort_layout.blob_layout;
20 if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) {
21 // Nested type, must be broken
22 return true;
23 }
24 const auto &tie_col_offset = row_layout.GetOffsets()[col_idx];
25 auto tie_string = Load<string_t>(ptr: row_ptr + tie_col_offset);
26 if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col]) {
27 // No need to break the tie - we already compared the full string
28 return false;
29 }
30 return true;
31}
32
33int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr,
34 const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) {
35 // Compare the sorting columns one by one
36 int comp_res = 0;
37 data_ptr_t l_ptr_offset = l_ptr;
38 data_ptr_t r_ptr_offset = r_ptr;
39 for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) {
40 comp_res = FastMemcmp(str1: l_ptr_offset, str2: r_ptr_offset, size: sort_layout.column_sizes[col_idx]);
41 if (comp_res == 0 && !sort_layout.constant_size[col_idx]) {
42 comp_res = BreakBlobTie(tie_col: col_idx, left, right, sort_layout, external: external_sort);
43 }
44 if (comp_res != 0) {
45 break;
46 }
47 l_ptr_offset += sort_layout.column_sizes[col_idx];
48 r_ptr_offset += sort_layout.column_sizes[col_idx];
49 }
50 return comp_res;
51}
52
53int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) {
54 switch (type.InternalType()) {
55 case PhysicalType::VARCHAR:
56 return TemplatedCompareVal<string_t>(left_ptr: l_ptr, right_ptr: r_ptr);
57 case PhysicalType::LIST:
58 case PhysicalType::STRUCT: {
59 auto l_nested_ptr = Load<data_ptr_t>(ptr: l_ptr);
60 auto r_nested_ptr = Load<data_ptr_t>(ptr: r_ptr);
61 return CompareValAndAdvance(l_ptr&: l_nested_ptr, r_ptr&: r_nested_ptr, type, valid: true);
62 }
63 default:
64 throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString());
65 }
66}
67
68int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right,
69 const SortLayout &sort_layout, const bool &external) {
70 data_ptr_t l_data_ptr = left.DataPtr(sd&: *left.sb->blob_sorting_data);
71 data_ptr_t r_data_ptr = right.DataPtr(sd&: *right.sb->blob_sorting_data);
72 if (!TieIsBreakable(tie_col, row_ptr: l_data_ptr, sort_layout)) {
73 // Quick check to see if ties can be broken
74 return 0;
75 }
76 // Align the pointers
77 const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(k: tie_col);
78 const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx];
79 l_data_ptr += tie_col_offset;
80 r_data_ptr += tie_col_offset;
81 // Do the comparison
82 const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1;
83 const auto &type = sort_layout.blob_layout.GetTypes()[col_idx];
84 int result;
85 if (external) {
86 // Store heap pointers
87 data_ptr_t l_heap_ptr = left.HeapPtr(sd&: *left.sb->blob_sorting_data);
88 data_ptr_t r_heap_ptr = right.HeapPtr(sd&: *right.sb->blob_sorting_data);
89 // Unswizzle offset to pointer
90 UnswizzleSingleValue(data_ptr: l_data_ptr, heap_ptr: l_heap_ptr, type);
91 UnswizzleSingleValue(data_ptr: r_data_ptr, heap_ptr: r_heap_ptr, type);
92 // Compare
93 result = CompareVal(l_ptr: l_data_ptr, r_ptr: r_data_ptr, type);
94 // Swizzle the pointers back to offsets
95 SwizzleSingleValue(data_ptr: l_data_ptr, heap_ptr: l_heap_ptr, type);
96 SwizzleSingleValue(data_ptr: r_data_ptr, heap_ptr: r_heap_ptr, type);
97 } else {
98 result = CompareVal(l_ptr: l_data_ptr, r_ptr: r_data_ptr, type);
99 }
100 return order * result;
101}
102
103template <class T>
104int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) {
105 const auto left_val = Load<T>(left_ptr);
106 const auto right_val = Load<T>(right_ptr);
107 if (Equals::Operation<T>(left_val, right_val)) {
108 return 0;
109 } else if (LessThan::Operation<T>(left_val, right_val)) {
110 return -1;
111 } else {
112 return 1;
113 }
114}
115
116int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) {
117 switch (type.InternalType()) {
118 case PhysicalType::BOOL:
119 case PhysicalType::INT8:
120 return TemplatedCompareAndAdvance<int8_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
121 case PhysicalType::INT16:
122 return TemplatedCompareAndAdvance<int16_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
123 case PhysicalType::INT32:
124 return TemplatedCompareAndAdvance<int32_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
125 case PhysicalType::INT64:
126 return TemplatedCompareAndAdvance<int64_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
127 case PhysicalType::UINT8:
128 return TemplatedCompareAndAdvance<uint8_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
129 case PhysicalType::UINT16:
130 return TemplatedCompareAndAdvance<uint16_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
131 case PhysicalType::UINT32:
132 return TemplatedCompareAndAdvance<uint32_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
133 case PhysicalType::UINT64:
134 return TemplatedCompareAndAdvance<uint64_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
135 case PhysicalType::INT128:
136 return TemplatedCompareAndAdvance<hugeint_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
137 case PhysicalType::FLOAT:
138 return TemplatedCompareAndAdvance<float>(left_ptr&: l_ptr, right_ptr&: r_ptr);
139 case PhysicalType::DOUBLE:
140 return TemplatedCompareAndAdvance<double>(left_ptr&: l_ptr, right_ptr&: r_ptr);
141 case PhysicalType::INTERVAL:
142 return TemplatedCompareAndAdvance<interval_t>(left_ptr&: l_ptr, right_ptr&: r_ptr);
143 case PhysicalType::VARCHAR:
144 return CompareStringAndAdvance(left_ptr&: l_ptr, right_ptr&: r_ptr, valid);
145 case PhysicalType::LIST:
146 return CompareListAndAdvance(left_ptr&: l_ptr, right_ptr&: r_ptr, type: ListType::GetChildType(type), valid);
147 case PhysicalType::STRUCT:
148 return CompareStructAndAdvance(left_ptr&: l_ptr, right_ptr&: r_ptr, types: StructType::GetChildTypes(type), valid);
149 default:
150 throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString());
151 }
152}
153
154template <class T>
155int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) {
156 auto result = TemplatedCompareVal<T>(left_ptr, right_ptr);
157 left_ptr += sizeof(T);
158 right_ptr += sizeof(T);
159 return result;
160}
161
162int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) {
163 if (!valid) {
164 return 0;
165 }
166 uint32_t left_string_size = Load<uint32_t>(ptr: left_ptr);
167 uint32_t right_string_size = Load<uint32_t>(ptr: right_ptr);
168 left_ptr += sizeof(uint32_t);
169 right_ptr += sizeof(uint32_t);
170 auto memcmp_res = memcmp(s1: const_char_ptr_cast(src: left_ptr), s2: const_char_ptr_cast(src: right_ptr),
171 n: std::min<uint32_t>(left_string_size, right_string_size));
172
173 left_ptr += left_string_size;
174 right_ptr += right_string_size;
175
176 if (memcmp_res != 0) {
177 return memcmp_res;
178 }
179 if (left_string_size == right_string_size) {
180 return 0;
181 }
182 if (left_string_size < right_string_size) {
183 return -1;
184 }
185 return 1;
186}
187
188int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr,
189 const child_list_t<LogicalType> &types, bool valid) {
190 idx_t count = types.size();
191 // Load validity masks
192 ValidityBytes left_validity(left_ptr);
193 ValidityBytes right_validity(right_ptr);
194 left_ptr += (count + 7) / 8;
195 right_ptr += (count + 7) / 8;
196 // Initialize variables
197 bool left_valid;
198 bool right_valid;
199 idx_t entry_idx;
200 idx_t idx_in_entry;
201 // Compare
202 int comp_res = 0;
203 for (idx_t i = 0; i < count; i++) {
204 ValidityBytes::GetEntryIndex(row_idx: i, entry_idx, idx_in_entry);
205 left_valid = left_validity.RowIsValid(entry: left_validity.GetValidityEntry(entry_idx), idx_in_entry);
206 right_valid = right_validity.RowIsValid(entry: right_validity.GetValidityEntry(entry_idx), idx_in_entry);
207 auto &type = types[i].second;
208 if ((left_valid == right_valid) || TypeIsConstantSize(type: type.InternalType())) {
209 comp_res = CompareValAndAdvance(l_ptr&: left_ptr, r_ptr&: right_ptr, type: types[i].second, valid: left_valid && valid);
210 }
211 if (!left_valid && !right_valid) {
212 comp_res = 0;
213 } else if (!left_valid) {
214 comp_res = 1;
215 } else if (!right_valid) {
216 comp_res = -1;
217 }
218 if (comp_res != 0) {
219 break;
220 }
221 }
222 return comp_res;
223}
224
225int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type,
226 bool valid) {
227 if (!valid) {
228 return 0;
229 }
230 // Load list lengths
231 auto left_len = Load<idx_t>(ptr: left_ptr);
232 auto right_len = Load<idx_t>(ptr: right_ptr);
233 left_ptr += sizeof(idx_t);
234 right_ptr += sizeof(idx_t);
235 // Load list validity masks
236 ValidityBytes left_validity(left_ptr);
237 ValidityBytes right_validity(right_ptr);
238 left_ptr += (left_len + 7) / 8;
239 right_ptr += (right_len + 7) / 8;
240 // Compare
241 int comp_res = 0;
242 idx_t count = MinValue(a: left_len, b: right_len);
243 if (TypeIsConstantSize(type: type.InternalType())) {
244 // Templated code for fixed-size types
245 switch (type.InternalType()) {
246 case PhysicalType::BOOL:
247 case PhysicalType::INT8:
248 comp_res = TemplatedCompareListLoop<int8_t>(left_ptr, right_ptr, left_validity, right_validity, count);
249 break;
250 case PhysicalType::INT16:
251 comp_res = TemplatedCompareListLoop<int16_t>(left_ptr, right_ptr, left_validity, right_validity, count);
252 break;
253 case PhysicalType::INT32:
254 comp_res = TemplatedCompareListLoop<int32_t>(left_ptr, right_ptr, left_validity, right_validity, count);
255 break;
256 case PhysicalType::INT64:
257 comp_res = TemplatedCompareListLoop<int64_t>(left_ptr, right_ptr, left_validity, right_validity, count);
258 break;
259 case PhysicalType::UINT8:
260 comp_res = TemplatedCompareListLoop<uint8_t>(left_ptr, right_ptr, left_validity, right_validity, count);
261 break;
262 case PhysicalType::UINT16:
263 comp_res = TemplatedCompareListLoop<uint16_t>(left_ptr, right_ptr, left_validity, right_validity, count);
264 break;
265 case PhysicalType::UINT32:
266 comp_res = TemplatedCompareListLoop<uint32_t>(left_ptr, right_ptr, left_validity, right_validity, count);
267 break;
268 case PhysicalType::UINT64:
269 comp_res = TemplatedCompareListLoop<uint64_t>(left_ptr, right_ptr, left_validity, right_validity, count);
270 break;
271 case PhysicalType::INT128:
272 comp_res = TemplatedCompareListLoop<hugeint_t>(left_ptr, right_ptr, left_validity, right_validity, count);
273 break;
274 case PhysicalType::FLOAT:
275 comp_res = TemplatedCompareListLoop<float>(left_ptr, right_ptr, left_validity, right_validity, count);
276 break;
277 case PhysicalType::DOUBLE:
278 comp_res = TemplatedCompareListLoop<double>(left_ptr, right_ptr, left_validity, right_validity, count);
279 break;
280 case PhysicalType::INTERVAL:
281 comp_res = TemplatedCompareListLoop<interval_t>(left_ptr, right_ptr, left_validity, right_validity, count);
282 break;
283 default:
284 throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString());
285 }
286 } else {
287 // Variable-sized list entries
288 bool left_valid;
289 bool right_valid;
290 idx_t entry_idx;
291 idx_t idx_in_entry;
292 // Size (in bytes) of all variable-sizes entries is stored before the entries begin,
293 // to make deserialization easier. We need to skip over them
294 left_ptr += left_len * sizeof(idx_t);
295 right_ptr += right_len * sizeof(idx_t);
296 for (idx_t i = 0; i < count; i++) {
297 ValidityBytes::GetEntryIndex(row_idx: i, entry_idx, idx_in_entry);
298 left_valid = left_validity.RowIsValid(entry: left_validity.GetValidityEntry(entry_idx), idx_in_entry);
299 right_valid = right_validity.RowIsValid(entry: right_validity.GetValidityEntry(entry_idx), idx_in_entry);
300 if (left_valid && right_valid) {
301 switch (type.InternalType()) {
302 case PhysicalType::LIST:
303 comp_res = CompareListAndAdvance(left_ptr, right_ptr, type: ListType::GetChildType(type), valid: left_valid);
304 break;
305 case PhysicalType::VARCHAR:
306 comp_res = CompareStringAndAdvance(left_ptr, right_ptr, valid: left_valid);
307 break;
308 case PhysicalType::STRUCT:
309 comp_res =
310 CompareStructAndAdvance(left_ptr, right_ptr, types: StructType::GetChildTypes(type), valid: left_valid);
311 break;
312 default:
313 throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString());
314 }
315 } else if (!left_valid && !right_valid) {
316 comp_res = 0;
317 } else if (left_valid) {
318 comp_res = -1;
319 } else {
320 comp_res = 1;
321 }
322 if (comp_res != 0) {
323 break;
324 }
325 }
326 }
327 // All values that we looped over were equal
328 if (comp_res == 0 && left_len != right_len) {
329 // Smaller lists first
330 if (left_len < right_len) {
331 comp_res = -1;
332 } else {
333 comp_res = 1;
334 }
335 }
336 return comp_res;
337}
338
339template <class T>
340int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr,
341 const ValidityBytes &left_validity, const ValidityBytes &right_validity,
342 const idx_t &count) {
343 int comp_res = 0;
344 bool left_valid;
345 bool right_valid;
346 idx_t entry_idx;
347 idx_t idx_in_entry;
348 for (idx_t i = 0; i < count; i++) {
349 ValidityBytes::GetEntryIndex(row_idx: i, entry_idx, idx_in_entry);
350 left_valid = left_validity.RowIsValid(entry: left_validity.GetValidityEntry(entry_idx), idx_in_entry);
351 right_valid = right_validity.RowIsValid(entry: right_validity.GetValidityEntry(entry_idx), idx_in_entry);
352 comp_res = TemplatedCompareAndAdvance<T>(left_ptr, right_ptr);
353 if (!left_valid && !right_valid) {
354 comp_res = 0;
355 } else if (!left_valid) {
356 comp_res = 1;
357 } else if (!right_valid) {
358 comp_res = -1;
359 }
360 if (comp_res != 0) {
361 break;
362 }
363 }
364 return comp_res;
365}
366
367void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) {
368 if (type.InternalType() == PhysicalType::VARCHAR) {
369 data_ptr += string_t::HEADER_SIZE;
370 }
371 Store<data_ptr_t>(val: heap_ptr + Load<idx_t>(ptr: data_ptr), ptr: data_ptr);
372}
373
374void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) {
375 if (type.InternalType() == PhysicalType::VARCHAR) {
376 data_ptr += string_t::HEADER_SIZE;
377 }
378 Store<idx_t>(val: Load<data_ptr_t>(ptr: data_ptr) - heap_ptr, ptr: data_ptr);
379}
380
381} // namespace duckdb
382