1 | #include "duckdb/execution/join_hashtable.hpp" |
2 | |
3 | #include "duckdb/common/exception.hpp" |
4 | #include "duckdb/common/row_operations/row_operations.hpp" |
5 | #include "duckdb/common/types/column/column_data_collection_segment.hpp" |
6 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
7 | #include "duckdb/main/client_context.hpp" |
8 | #include "duckdb/storage/buffer_manager.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | using ValidityBytes = JoinHashTable::ValidityBytes; |
13 | using ScanStructure = JoinHashTable::ScanStructure; |
14 | using ProbeSpill = JoinHashTable::ProbeSpill; |
15 | using ProbeSpillLocalState = JoinHashTable::ProbeSpillLocalAppendState; |
16 | |
// Constructs the join hash table: validates/classifies the join conditions,
// derives the row layout [keys, payload, (optional found-flag), hash], and
// creates the row collections used during the build phase.
JoinHashTable::JoinHashTable(BufferManager &buffer_manager_p, const vector<JoinCondition> &conditions_p,
                             vector<LogicalType> btypes, JoinType type_p)
    : buffer_manager(buffer_manager_p), conditions(conditions_p), build_types(std::move(btypes)), entry_size(0),
      tuple_size(0), vfound(Value::BOOLEAN(value: false)), join_type(type_p), finalized(false), has_null(false),
      external(false), radix_bits(4), partition_start(0), partition_end(0) {
	for (auto &condition : conditions) {
		// both sides of a condition must compare values of the same type
		D_ASSERT(condition.left->return_type == condition.right->return_type);
		auto type = condition.left->return_type;
		if (condition.comparison == ExpressionType::COMPARE_EQUAL ||
		    condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM ||
		    condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM) {
			// all equality conditions should be at the front
			// all other conditions at the back
			// this assert checks that
			D_ASSERT(equality_types.size() == condition_types.size());
			equality_types.push_back(x: type);
		}

		predicates.push_back(x: condition.comparison);
		// DISTINCT FROM comparisons treat NULL as a regular (equal-to-NULL) value
		null_values_are_equal.push_back(x: condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM ||
		                                   condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM);

		condition_types.push_back(x: type);
	}
	// at least one equality is necessary
	D_ASSERT(!equality_types.empty());

	// Types for the layout
	vector<LogicalType> layout_types(condition_types);
	layout_types.insert(position: layout_types.end(), first: build_types.begin(), last: build_types.end());
	if (IsRightOuterJoin(type: join_type)) {
		// full/right outer joins need an extra bool to keep track of whether or not a tuple has found a matching entry
		// we place the bool before the NEXT pointer
		layout_types.emplace_back(args: LogicalType::BOOLEAN);
	}
	// the hash doubles as the "next" chain pointer once the table is finalized
	layout_types.emplace_back(args: LogicalType::HASH);
	layout.Initialize(types: layout_types, align: false);

	const auto &offsets = layout.GetOffsets();
	// tuple_size: width of keys + payload (offset of the first extra column)
	tuple_size = offsets[condition_types.size() + build_types.size()];
	// pointer_offset: where the hash / chain pointer lives within a row
	pointer_offset = offsets.back();
	entry_size = layout.GetRowWidth();

	data_collection = make_uniq<TupleDataCollection>(args&: buffer_manager, args&: layout);
	// sink partitions incoming build data by radix of the hash (last layout column)
	sink_collection =
	    make_uniq<RadixPartitionedTupleData>(args&: buffer_manager, args&: layout, args&: radix_bits, args: layout.ColumnCount() - 1);
}
64 | |
// All members are RAII-managed; nothing to release explicitly.
JoinHashTable::~JoinHashTable() {
}
67 | |
// Merges another (thread-local) hash table into this one. Data structures are
// combined under their respective locks; "other" is drained in the process.
void JoinHashTable::Merge(JoinHashTable &other) {
	{
		// combine the materialized build-side rows under the data lock
		lock_guard<mutex> guard(data_lock);
		data_collection->Combine(other&: *other.data_collection);
	}

	if (join_type == JoinType::MARK) {
		// MARK joins also track whether any NULL key was seen, plus the
		// correlated COUNT aggregates (if this is a correlated mark join)
		auto &info = correlated_mark_join_info;
		lock_guard<mutex> mj_lock(info.mj_lock);
		has_null = has_null || other.has_null;
		if (!info.correlated_types.empty()) {
			auto &other_info = other.correlated_mark_join_info;
			info.correlated_counts->Combine(other&: *other_info.correlated_counts);
		}
	}

	// merge the radix-partitioned sink data (partition-wise, internally synchronized)
	sink_collection->Combine(other&: *other.sink_collection);
}
86 | |
87 | void JoinHashTable::ApplyBitmask(Vector &hashes, idx_t count) { |
88 | if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
89 | D_ASSERT(!ConstantVector::IsNull(hashes)); |
90 | auto indices = ConstantVector::GetData<hash_t>(vector&: hashes); |
91 | *indices = *indices & bitmask; |
92 | } else { |
93 | hashes.Flatten(count); |
94 | auto indices = FlatVector::GetData<hash_t>(vector&: hashes); |
95 | for (idx_t i = 0; i < count; i++) { |
96 | indices[i] &= bitmask; |
97 | } |
98 | } |
99 | } |
100 | |
// Translates hashes into addresses of the corresponding pointer-table slots,
// writing the slot addresses into "pointers" at the selected row positions.
void JoinHashTable::ApplyBitmask(Vector &hashes, const SelectionVector &sel, idx_t count, Vector &pointers) {
	UnifiedVectorFormat hdata;
	hashes.ToUnifiedFormat(count, data&: hdata);

	auto hash_data = UnifiedVectorFormat::GetData<hash_t>(format: hdata);
	auto result_data = FlatVector::GetData<data_ptr_t *>(vector&: pointers);
	auto main_ht = reinterpret_cast<data_ptr_t *>(hash_map.get());
	for (idx_t i = 0; i < count; i++) {
		auto rindex = sel.get_index(idx: i);
		auto hindex = hdata.sel->get_index(idx: rindex);
		auto hash = hash_data[hindex];
		// store the address of the slot, not its contents; the slot is only
		// dereferenced later (see ScanStructure::InitializeSelectionVector)
		result_data[rindex] = main_ht + (hash & bitmask);
	}
}
115 | |
116 | void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) { |
117 | if (count == keys.size()) { |
118 | // no null values are filtered: use regular hash functions |
119 | VectorOperations::Hash(input&: keys.data[0], hashes, count: keys.size()); |
120 | for (idx_t i = 1; i < equality_types.size(); i++) { |
121 | VectorOperations::CombineHash(hashes, input&: keys.data[i], count: keys.size()); |
122 | } |
123 | } else { |
124 | // null values were filtered: use selection vector |
125 | VectorOperations::Hash(input&: keys.data[0], hashes, rsel: sel, count); |
126 | for (idx_t i = 1; i < equality_types.size(); i++) { |
127 | VectorOperations::CombineHash(hashes, input&: keys.data[i], rsel: sel, count); |
128 | } |
129 | } |
130 | } |
131 | |
132 | static idx_t FilterNullValues(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, |
133 | SelectionVector &result) { |
134 | idx_t result_count = 0; |
135 | for (idx_t i = 0; i < count; i++) { |
136 | auto idx = sel.get_index(idx: i); |
137 | auto key_idx = vdata.sel->get_index(idx); |
138 | if (vdata.validity.RowIsValid(row_idx: key_idx)) { |
139 | result.set_index(idx: result_count++, loc: idx); |
140 | } |
141 | } |
142 | return result_count; |
143 | } |
144 | |
// Prepares key columns for hashing/probing: converts them to unified format
// and filters out rows with NULL keys for columns where NULL != NULL.
// On return, "current_sel" points either at the incremental selection vector
// (no rows filtered) or at "sel" (filtered); returns the surviving row count.
idx_t JoinHashTable::PrepareKeys(DataChunk &keys, unsafe_unique_array<UnifiedVectorFormat> &key_data,
                                 const SelectionVector *&current_sel, SelectionVector &sel, bool build_side) {
	key_data = keys.ToUnifiedFormat();

	// figure out which keys are NULL, and create a selection vector out of them
	current_sel = FlatVector::IncrementalSelectionVector();
	idx_t added_count = keys.size();
	if (build_side && IsRightOuterJoin(type: join_type)) {
		// in case of a right or full outer join, we cannot remove NULL keys from the build side
		return added_count;
	}
	for (idx_t i = 0; i < keys.ColumnCount(); i++) {
		if (!null_values_are_equal[i]) {
			if (key_data[i].validity.AllValid()) {
				// no NULLs in this column: nothing to filter
				continue;
			}
			added_count = FilterNullValues(vdata&: key_data[i], sel: *current_sel, count: added_count, result&: sel);
			// null values are NOT equal for this column, filter them out
			// (subsequent iterations re-filter "sel" in place, which is safe
			// because FilterNullValues only ever compacts forward)
			current_sel = &sel;
		}
	}
	return added_count;
}
168 | |
169 | void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &payload) { |
170 | D_ASSERT(!finalized); |
171 | D_ASSERT(keys.size() == payload.size()); |
172 | if (keys.size() == 0) { |
173 | return; |
174 | } |
175 | // special case: correlated mark join |
176 | if (join_type == JoinType::MARK && !correlated_mark_join_info.correlated_types.empty()) { |
177 | auto &info = correlated_mark_join_info; |
178 | lock_guard<mutex> mj_lock(info.mj_lock); |
179 | // Correlated MARK join |
180 | // for the correlated mark join we need to keep track of COUNT(*) and COUNT(COLUMN) for each of the correlated |
181 | // columns push into the aggregate hash table |
182 | D_ASSERT(info.correlated_counts); |
183 | info.group_chunk.SetCardinality(keys); |
184 | for (idx_t i = 0; i < info.correlated_types.size(); i++) { |
185 | info.group_chunk.data[i].Reference(other&: keys.data[i]); |
186 | } |
187 | if (info.correlated_payload.data.empty()) { |
188 | vector<LogicalType> types; |
189 | types.push_back(x: keys.data[info.correlated_types.size()].GetType()); |
190 | info.correlated_payload.InitializeEmpty(types); |
191 | } |
192 | info.correlated_payload.SetCardinality(keys); |
193 | info.correlated_payload.data[0].Reference(other&: keys.data[info.correlated_types.size()]); |
194 | AggregateHTAppendState append_state; |
195 | info.correlated_counts->AddChunk(state&: append_state, groups&: info.group_chunk, payload&: info.correlated_payload, |
196 | filter: AggregateType::NON_DISTINCT); |
197 | } |
198 | |
199 | // prepare the keys for processing |
200 | unsafe_unique_array<UnifiedVectorFormat> key_data; |
201 | const SelectionVector *current_sel; |
202 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
203 | idx_t added_count = PrepareKeys(keys, key_data, current_sel, sel, build_side: true); |
204 | if (added_count < keys.size()) { |
205 | has_null = true; |
206 | } |
207 | if (added_count == 0) { |
208 | return; |
209 | } |
210 | |
211 | // hash the keys and obtain an entry in the list |
212 | // note that we only hash the keys used in the equality comparison |
213 | Vector hash_values(LogicalType::HASH); |
214 | Hash(keys, sel: *current_sel, count: added_count, hashes&: hash_values); |
215 | |
216 | // build a chunk to append to the data collection [keys, payload, (optional "found" boolean), hash] |
217 | DataChunk source_chunk; |
218 | source_chunk.InitializeEmpty(types: layout.GetTypes()); |
219 | for (idx_t i = 0; i < keys.ColumnCount(); i++) { |
220 | source_chunk.data[i].Reference(other&: keys.data[i]); |
221 | } |
222 | idx_t col_offset = keys.ColumnCount(); |
223 | D_ASSERT(build_types.size() == payload.ColumnCount()); |
224 | for (idx_t i = 0; i < payload.ColumnCount(); i++) { |
225 | source_chunk.data[col_offset + i].Reference(other&: payload.data[i]); |
226 | } |
227 | col_offset += payload.ColumnCount(); |
228 | if (IsRightOuterJoin(type: join_type)) { |
229 | // for FULL/RIGHT OUTER joins initialize the "found" boolean to false |
230 | source_chunk.data[col_offset].Reference(other&: vfound); |
231 | col_offset++; |
232 | } |
233 | source_chunk.data[col_offset].Reference(other&: hash_values); |
234 | source_chunk.SetCardinality(keys); |
235 | |
236 | if (added_count < keys.size()) { |
237 | source_chunk.Slice(sel_vector: *current_sel, count: added_count); |
238 | } |
239 | sink_collection->Append(state&: append_state, input&: source_chunk); |
240 | } |
241 | |
// Inserts "count" rows at the head of their hash-table chains.
// pointers:      the pointer table (chain heads), indexed by masked hash
// indices:       masked hashes (slot indices), one per row
// key_locations: row addresses being inserted
// pointer_offset: byte offset of the "next" pointer within a row
// When PARALLEL, uses a CAS loop so concurrent inserts into the same slot
// are linearized; otherwise plain loads/stores suffice.
template <bool PARALLEL>
static inline void InsertHashesLoop(atomic<data_ptr_t> pointers[], const hash_t indices[], const idx_t count,
                                    const data_ptr_t key_locations[], const idx_t pointer_offset) {
	for (idx_t i = 0; i < count; i++) {
		const auto index = indices[i];
		if (PARALLEL) {
			data_ptr_t head;
			do {
				// link the row to the current head, then try to publish it as
				// the new head; retry if another thread won the race
				head = pointers[index];
				Store<data_ptr_t>(val: head, ptr: key_locations[i] + pointer_offset);
			} while (!std::atomic_compare_exchange_weak(a: &pointers[index], i1: &head, i2: key_locations[i]));
		} else {
			// set prev in current key to the value (NOTE: this will be nullptr if there is none)
			Store<data_ptr_t>(val: pointers[index], ptr: key_locations[i] + pointer_offset);

			// set pointer to current tuple
			pointers[index] = key_locations[i];
		}
	}
}
262 | |
// Inserts "count" rows (at key_locations) into the pointer table, chaining
// them at the head of the slot given by each row's masked hash.
void JoinHashTable::InsertHashes(Vector &hashes, idx_t count, data_ptr_t key_locations[], bool parallel) {
	D_ASSERT(hashes.GetType().id() == LogicalType::HASH);

	// use bitmask to get position in array
	ApplyBitmask(hashes, count);

	hashes.Flatten(count);
	D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR);

	auto pointers = reinterpret_cast<atomic<data_ptr_t> *>(hash_map.get());
	auto indices = FlatVector::GetData<hash_t>(vector&: hashes);

	// dispatch to the CAS-based loop only when multiple threads insert concurrently
	if (parallel) {
		InsertHashesLoop<true>(pointers, indices, count, key_locations, pointer_offset);
	} else {
		InsertHashesLoop<false>(pointers, indices, count, key_locations, pointer_offset);
	}
}
281 | |
// Allocates (or reuses) the pointer table sized for the current row count and
// zero-initializes it. Sets "bitmask" so that hash & bitmask indexes a slot.
void JoinHashTable::InitializePointerTable() {
	idx_t capacity = PointerTableCapacity(count: Count());
	// the masking scheme requires a power-of-two capacity
	D_ASSERT(IsPowerOfTwo(capacity));

	if (hash_map.get()) {
		// There is already a hash map
		auto current_capacity = hash_map.GetSize() / sizeof(data_ptr_t);
		if (capacity > current_capacity) {
			// Need more space
			hash_map = buffer_manager.GetBufferAllocator().Allocate(size: capacity * sizeof(data_ptr_t));
		} else {
			// Just use the current hash map
			capacity = current_capacity;
		}
	} else {
		// Allocate a hash map
		hash_map = buffer_manager.GetBufferAllocator().Allocate(size: capacity * sizeof(data_ptr_t));
	}
	D_ASSERT(hash_map.GetSize() == capacity * sizeof(data_ptr_t));

	// initialize HT with all-zero entries
	std::fill_n(reinterpret_cast<data_ptr_t *>(hash_map.get()), capacity, nullptr);

	bitmask = capacity - 1;
}
307 | |
// Inserts all materialized rows in chunks [chunk_idx_from, chunk_idx_to) into
// the pointer table. Each row's hash was stored at pointer_offset during the
// build; it is read back here and used to chain the row into its slot.
// May run concurrently over disjoint chunk ranges when "parallel" is set.
void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool parallel) {
	// Pointer table should be allocated
	D_ASSERT(hash_map.get());

	Vector hashes(LogicalType::HASH);
	auto hash_data = FlatVector::GetData<hash_t>(vector&: hashes);

	// keep everything pinned: the pointer table stores raw row addresses
	TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED, chunk_idx_from,
	                                chunk_idx_to, false);
	const auto row_locations = iterator.GetRowLocations();
	do {
		const auto count = iterator.GetCurrentChunkCount();
		for (idx_t i = 0; i < count; i++) {
			// re-load the hash that Build() stored in the row
			hash_data[i] = Load<hash_t>(ptr: row_locations[i] + pointer_offset);
		}
		InsertHashes(hashes, count, key_locations: row_locations, parallel);
	} while (iterator.Next());
}
326 | |
// Creates a probe-side scan structure for "keys": allocates the found-match
// bitmap for non-inner joins and filters NULL keys via PrepareKeys.
unique_ptr<ScanStructure> JoinHashTable::InitializeScanStructure(DataChunk &keys, const SelectionVector *&current_sel) {
	D_ASSERT(Count() > 0); // should be handled before
	D_ASSERT(finalized);

	// set up the scan structure
	auto ss = make_uniq<ScanStructure>(args&: *this);

	if (join_type != JoinType::INNER) {
		// non-inner joins need to know, per LHS row, whether any match was found
		ss->found_match = make_unsafe_uniq_array<bool>(STANDARD_VECTOR_SIZE);
		memset(s: ss->found_match.get(), c: 0, n: sizeof(bool) * STANDARD_VECTOR_SIZE);
	}

	// first prepare the keys for probing
	ss->count = PrepareKeys(keys, key_data&: ss->key_data, current_sel, sel&: ss->sel_vector, build_side: false);
	return ss;
}
343 | |
// Probes the hash table with a chunk of keys and returns a scan structure
// positioned at the first candidate row of each chain. Hashes may be supplied
// precomputed (e.g. by an external/partitioned join) to avoid re-hashing.
unique_ptr<ScanStructure> JoinHashTable::Probe(DataChunk &keys, Vector *precomputed_hashes) {
	const SelectionVector *current_sel;
	auto ss = InitializeScanStructure(keys, current_sel);
	if (ss->count == 0) {
		// all keys were NULL-filtered: nothing to probe
		return ss;
	}

	if (precomputed_hashes) {
		ApplyBitmask(hashes&: *precomputed_hashes, sel: *current_sel, count: ss->count, pointers&: ss->pointers);
	} else {
		// hash all the keys
		Vector hashes(LogicalType::HASH);
		Hash(keys, sel: *current_sel, count: ss->count, hashes);

		// now initialize the pointers of the scan structure based on the hashes
		ApplyBitmask(hashes, sel: *current_sel, count: ss->count, pointers&: ss->pointers);
	}

	// create the selection vector linking to only non-empty entries
	ss->InitializeSelectionVector(current_sel);

	return ss;
}
367 | |
// Per-probe-chunk state: "pointers" holds the current chain position of each
// LHS row, "sel_vector" the rows that still have candidates to examine.
ScanStructure::ScanStructure(JoinHashTable &ht)
    : pointers(LogicalType::POINTER), sel_vector(STANDARD_VECTOR_SIZE), ht(ht), finished(false) {
}
371 | |
// Produces the next chunk of join output for the current probe chunk,
// dispatching on the join type. INNER/RIGHT may need multiple calls per probe
// chunk; the other types set "finished" after a single pass.
void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) {
	if (finished) {
		return;
	}
	switch (ht.join_type) {
	case JoinType::INNER:
	case JoinType::RIGHT:
		NextInnerJoin(keys, left, result);
		break;
	case JoinType::SEMI:
		NextSemiJoin(keys, left, result);
		break;
	case JoinType::MARK:
		NextMarkJoin(keys, left, result);
		break;
	case JoinType::ANTI:
		NextAntiJoin(keys, left, result);
		break;
	case JoinType::OUTER:
	case JoinType::LEFT:
		NextLeftJoin(keys, left, result);
		break;
	case JoinType::SINGLE:
		NextSingleJoin(keys, left, result);
		break;
	default:
		throw InternalException("Unhandled join type in JoinHashTable" );
	}
}
401 | |
// Compares the probe keys against the rows currently pointed at, using all
// join predicates. Fills match_sel with matching row indices (returning the
// match count) and, if provided, no_match_sel with the rest.
idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) {
	// Start with the scan selection
	for (idx_t i = 0; i < this->count; ++i) {
		match_sel.set_index(idx: i, loc: this->sel_vector.get_index(idx: i));
	}
	idx_t no_match_count = 0;

	return RowOperations::Match(columns&: keys, col_data: key_data.get(), layout: ht.layout, rows&: pointers, predicates: ht.predicates, sel&: match_sel, count: this->count,
	                            no_match: no_match_sel, no_match_count);
}
412 | |
// Advances along the chains until at least one match is found (or all chains
// are exhausted). Returns the number of matches written into result_vector;
// 0 means the probe chunk is done.
idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector) {
	while (true) {
		// resolve the predicates for this set of keys
		idx_t result_count = ResolvePredicates(keys, match_sel&: result_vector, no_match_sel: nullptr);

		// after doing all the comparisons set the found_match vector
		if (found_match) {
			// (only allocated for non-INNER joins, e.g. RIGHT/OUTER)
			for (idx_t i = 0; i < result_count; i++) {
				auto idx = result_vector.get_index(idx: i);
				found_match[idx] = true;
			}
		}
		if (result_count > 0) {
			return result_count;
		}
		// no matches found: check the next set of pointers
		AdvancePointers();
		if (this->count == 0) {
			return 0;
		}
	}
}
435 | |
436 | void ScanStructure::AdvancePointers(const SelectionVector &sel, idx_t sel_count) { |
437 | // now for all the pointers, we move on to the next set of pointers |
438 | idx_t new_count = 0; |
439 | auto ptrs = FlatVector::GetData<data_ptr_t>(vector&: this->pointers); |
440 | for (idx_t i = 0; i < sel_count; i++) { |
441 | auto idx = sel.get_index(idx: i); |
442 | ptrs[idx] = Load<data_ptr_t>(ptr: ptrs[idx] + ht.pointer_offset); |
443 | if (ptrs[idx]) { |
444 | this->sel_vector.set_index(idx: new_count++, loc: idx); |
445 | } |
446 | } |
447 | this->count = new_count; |
448 | } |
449 | |
450 | void ScanStructure::InitializeSelectionVector(const SelectionVector *¤t_sel) { |
451 | idx_t non_empty_count = 0; |
452 | auto ptrs = FlatVector::GetData<data_ptr_t>(vector&: pointers); |
453 | auto cnt = count; |
454 | for (idx_t i = 0; i < cnt; i++) { |
455 | const auto idx = current_sel->get_index(idx: i); |
456 | ptrs[idx] = Load<data_ptr_t>(ptr: ptrs[idx]); |
457 | if (ptrs[idx]) { |
458 | sel_vector.set_index(idx: non_empty_count++, loc: idx); |
459 | } |
460 | } |
461 | count = non_empty_count; |
462 | } |
463 | |
// Advances all rows that still have candidates (the current selection).
void ScanStructure::AdvancePointers() {
	AdvancePointers(sel: this->sel_vector, sel_count: this->count);
}
467 | |
// Gathers column "col_no" from the matched rows ("pointers" selected by
// sel_vector) into "result", placed at the positions in result_vector.
void ScanStructure::GatherResult(Vector &result, const SelectionVector &result_vector,
                                 const SelectionVector &sel_vector, const idx_t count, const idx_t col_no) {
	ht.data_collection->Gather(row_locations&: pointers, sel: sel_vector, scan_count: count, column_id: col_no, result, target_sel: result_vector);
}
472 | |
// Convenience overload: gather into consecutive result positions (0..count).
void ScanStructure::GatherResult(Vector &result, const SelectionVector &sel_vector, const idx_t count,
                                 const idx_t col_idx) {
	GatherResult(result, result_vector: *FlatVector::IncrementalSelectionVector(), sel_vector, count, col_no: col_idx);
}
477 | |
// Produces one chunk of INNER (or RIGHT, via the found-flag) join output:
// LHS columns are sliced from "left", RHS columns gathered from the HT rows.
// Call repeatedly until result is empty (count reaches 0).
void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &result) {
	D_ASSERT(result.ColumnCount() == left.ColumnCount() + ht.build_types.size());
	if (this->count == 0) {
		// no pointers left to chase
		return;
	}

	SelectionVector result_vector(STANDARD_VECTOR_SIZE);

	idx_t result_count = ScanInnerJoin(keys, result_vector);
	if (result_count > 0) {
		if (IsRightOuterJoin(type: ht.join_type)) {
			// full/right outer join: mark join matches as FOUND in the HT
			auto ptrs = FlatVector::GetData<data_ptr_t>(vector&: pointers);
			for (idx_t i = 0; i < result_count; i++) {
				auto idx = result_vector.get_index(idx: i);
				// NOTE: threadsan reports this as a data race because this can be set concurrently by separate threads
				// Technically it is, but it does not matter, since the only value that can be written is "true"
				Store<bool>(val: true, ptr: ptrs[idx] + ht.tuple_size);
			}
		}
		// matches were found
		// construct the result
		// on the LHS, we create a slice using the result vector
		result.Slice(other&: left, sel: result_vector, count: result_count);

		// on the RHS, we need to fetch the data from the hash table
		for (idx_t i = 0; i < ht.build_types.size(); i++) {
			auto &vector = result.data[left.ColumnCount() + i];
			D_ASSERT(vector.GetType() == ht.build_types[i]);
			// build columns are laid out after the condition (key) columns
			GatherResult(result&: vector, sel_vector: result_vector, count: result_count, col_idx: i + ht.condition_types.size());
		}
		// move to the next chain entries for the next call
		AdvancePointers();
	}
}
513 | |
// Chases every chain to completion, setting found_match[i] for each LHS row
// that matches at least one HT row. Used by SEMI/ANTI/MARK joins, which need
// only a per-row boolean rather than the matching rows themselves.
void ScanStructure::ScanKeyMatches(DataChunk &keys) {
	// the semi-join, anti-join and mark-join we handle a differently from the inner join
	// since there can be at most STANDARD_VECTOR_SIZE results
	// we handle the entire chunk in one call to Next().
	// for every pointer, we keep chasing pointers and doing comparisons.
	// this results in a boolean array indicating whether or not the tuple has a match
	SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE);
	while (this->count > 0) {
		// resolve the predicates for the current set of pointers
		idx_t match_count = ResolvePredicates(keys, match_sel, no_match_sel: &no_match_sel);
		idx_t no_match_count = this->count - match_count;

		// mark each of the matches as found
		for (idx_t i = 0; i < match_count; i++) {
			found_match[match_sel.get_index(idx: i)] = true;
		}
		// continue searching for the ones where we did not find a match yet
		// (matched rows are done: one match suffices for these join types)
		AdvancePointers(sel: no_match_sel, sel_count: no_match_count);
	}
}
534 | |
535 | template <bool MATCH> |
536 | void ScanStructure::NextSemiOrAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { |
537 | D_ASSERT(left.ColumnCount() == result.ColumnCount()); |
538 | D_ASSERT(keys.size() == left.size()); |
539 | // create the selection vector from the matches that were found |
540 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
541 | idx_t result_count = 0; |
542 | for (idx_t i = 0; i < keys.size(); i++) { |
543 | if (found_match[i] == MATCH) { |
544 | // part of the result |
545 | sel.set_index(idx: result_count++, loc: i); |
546 | } |
547 | } |
548 | // construct the final result |
549 | if (result_count > 0) { |
550 | // we only return the columns on the left side |
551 | // reference the columns of the left side from the result |
552 | result.Slice(other&: left, sel, count: result_count); |
553 | } else { |
554 | D_ASSERT(result.size() == 0); |
555 | } |
556 | } |
557 | |
558 | void ScanStructure::NextSemiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { |
559 | // first scan for key matches |
560 | ScanKeyMatches(keys); |
561 | // then construct the result from all tuples with a match |
562 | NextSemiOrAntiJoin<true>(keys, left, result); |
563 | |
564 | finished = true; |
565 | } |
566 | |
567 | void ScanStructure::NextAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { |
568 | // first scan for key matches |
569 | ScanKeyMatches(keys); |
570 | // then construct the result from all tuples that did not find a match |
571 | NextSemiOrAntiJoin<false>(keys, left, result); |
572 | |
573 | finished = true; |
574 | } |
575 | |
// Builds the MARK-join output (uncorrelated case): all LHS columns plus a
// trailing BOOLEAN "mark" column with three-valued logic:
//   TRUE  - a match was found
//   FALSE - no match and no NULLs involved
//   NULL  - no match, but the key (or the RHS) contains NULLs
void ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &child, DataChunk &result) {
	// for the initial set of columns we just reference the left side
	result.SetCardinality(child);
	for (idx_t i = 0; i < child.ColumnCount(); i++) {
		result.data[i].Reference(other&: child.data[i]);
	}
	auto &mark_vector = result.data.back();
	mark_vector.SetVectorType(VectorType::FLAT_VECTOR);
	// first we set the NULL values from the join keys
	// if there is any NULL in the keys, the result is NULL
	auto bool_result = FlatVector::GetData<bool>(vector&: mark_vector);
	auto &mask = FlatVector::Validity(vector&: mark_vector);
	for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) {
		if (ht.null_values_are_equal[col_idx]) {
			// DISTINCT FROM semantics: NULL is an ordinary value here
			continue;
		}
		UnifiedVectorFormat jdata;
		join_keys.data[col_idx].ToUnifiedFormat(count: join_keys.size(), data&: jdata);
		if (!jdata.validity.AllValid()) {
			for (idx_t i = 0; i < join_keys.size(); i++) {
				auto jidx = jdata.sel->get_index(idx: i);
				mask.Set(row_idx: i, valid: jdata.validity.RowIsValidUnsafe(row_idx: jidx));
			}
		}
	}
	// now set the remaining entries to either true or false based on whether a match was found
	if (found_match) {
		for (idx_t i = 0; i < child.size(); i++) {
			bool_result[i] = found_match[i];
		}
	} else {
		memset(s: bool_result, c: 0, n: sizeof(bool) * child.size());
	}
	// if the right side contains NULL values, the result of any FALSE becomes NULL
	if (ht.has_null) {
		for (idx_t i = 0; i < child.size(); i++) {
			if (!bool_result[i]) {
				mask.SetInvalid(i);
			}
		}
	}
}
618 | |
// MARK join driver: scans all chains for matches, then constructs the result.
// The uncorrelated case delegates to ConstructMarkJoinResult; the correlated
// case uses the COUNT(*)/COUNT(col) aggregates to decide TRUE/FALSE/NULL per
// the semantics of correlated IN/EXISTS subqueries.
void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &input, DataChunk &result) {
	D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1);
	D_ASSERT(result.data.back().GetType() == LogicalType::BOOLEAN);
	// this method should only be called for a non-empty HT
	D_ASSERT(ht.Count() > 0);

	ScanKeyMatches(keys);
	if (ht.correlated_mark_join_info.correlated_types.empty()) {
		ConstructMarkJoinResult(join_keys&: keys, child&: input, result);
	} else {
		auto &info = ht.correlated_mark_join_info;
		lock_guard<mutex> mj_lock(info.mj_lock);

		// there are correlated columns
		// first we fetch the counts from the aggregate hashtable corresponding to these entries
		D_ASSERT(keys.ColumnCount() == info.group_chunk.ColumnCount() + 1);
		info.group_chunk.SetCardinality(keys);
		for (idx_t i = 0; i < info.group_chunk.ColumnCount(); i++) {
			info.group_chunk.data[i].Reference(other&: keys.data[i]);
		}
		info.correlated_counts->FetchAggregates(groups&: info.group_chunk, result&: info.result_chunk);

		// for the initial set of columns we just reference the left side
		result.SetCardinality(input);
		for (idx_t i = 0; i < input.ColumnCount(); i++) {
			result.data[i].Reference(other&: input.data[i]);
		}
		// create the result matching vector
		// the last key column is the actual mark key (after the correlated columns)
		auto &last_key = keys.data.back();
		auto &result_vector = result.data.back();
		// first set the nullmask based on whether or not there were NULL values in the join key
		result_vector.SetVectorType(VectorType::FLAT_VECTOR);
		auto bool_result = FlatVector::GetData<bool>(vector&: result_vector);
		auto &mask = FlatVector::Validity(vector&: result_vector);
		switch (last_key.GetVectorType()) {
		case VectorType::CONSTANT_VECTOR:
			if (ConstantVector::IsNull(vector: last_key)) {
				mask.SetAllInvalid(input.size());
			}
			break;
		case VectorType::FLAT_VECTOR:
			mask.Copy(other: FlatVector::Validity(vector&: last_key), count: input.size());
			break;
		default: {
			// generic path: materialize validity through the unified format
			UnifiedVectorFormat kdata;
			last_key.ToUnifiedFormat(count: keys.size(), data&: kdata);
			for (idx_t i = 0; i < input.size(); i++) {
				auto kidx = kdata.sel->get_index(idx: i);
				mask.Set(row_idx: i, valid: kdata.validity.RowIsValid(row_idx: kidx));
			}
			break;
		}
		}

		// result_chunk: [0] = COUNT(*), [1] = COUNT(column) per group
		auto count_star = FlatVector::GetData<int64_t>(vector&: info.result_chunk.data[0]);
		auto count = FlatVector::GetData<int64_t>(vector&: info.result_chunk.data[1]);
		// set the entries to either true or false based on whether a match was found
		for (idx_t i = 0; i < input.size(); i++) {
			D_ASSERT(count_star[i] >= count[i]);
			bool_result[i] = found_match ? found_match[i] : false;
			if (!bool_result[i] && count_star[i] > count[i]) {
				// RHS has NULL value and result is false: set to null
				mask.SetInvalid(i);
			}
			if (count_star[i] == 0) {
				// count == 0, set nullmask to false (we know the result is false now)
				mask.SetValid(i);
			}
		}
	}
	finished = true;
}
691 | |
// LEFT/FULL OUTER join: behaves as an inner join until the chains are
// exhausted; the final call emits the unmatched LHS rows with NULL-padded
// RHS columns.
void ScanStructure::NextLeftJoin(DataChunk &keys, DataChunk &left, DataChunk &result) {
	// a LEFT OUTER JOIN is identical to an INNER JOIN except all tuples that do
	// not have a match must return at least one tuple (with the right side set
	// to NULL in every column)
	NextInnerJoin(keys, left, result);
	if (result.size() == 0) {
		// no entries left from the normal join
		// fill in the result of the remaining left tuples
		// together with NULL values on the right-hand side
		idx_t remaining_count = 0;
		SelectionVector sel(STANDARD_VECTOR_SIZE);
		for (idx_t i = 0; i < left.size(); i++) {
			if (!found_match[i]) {
				sel.set_index(idx: remaining_count++, loc: i);
			}
		}
		if (remaining_count > 0) {
			// have remaining tuples
			// slice the left side with tuples that did not find a match
			result.Slice(other&: left, sel, count: remaining_count);

			// now set the right side to NULL
			for (idx_t i = left.ColumnCount(); i < result.ColumnCount(); i++) {
				Vector &vec = result.data[i];
				vec.SetVectorType(VectorType::CONSTANT_VECTOR);
				ConstantVector::SetNull(vector&: vec, is_null: true);
			}
		}
		finished = true;
	}
}
723 | |
// SINGLE join: produces exactly one output row per LHS row — the first match
// found in the HT, or NULLs on the RHS if there is none. Used for scalar
// subqueries, which must yield at most one value per input row.
void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &input, DataChunk &result) {
	// single join
	// this join is similar to the semi join except that
	// (1) we actually return data from the RHS and
	// (2) we return NULL for that data if there is no match
	idx_t result_count = 0;
	SelectionVector result_sel(STANDARD_VECTOR_SIZE);
	SelectionVector match_sel(STANDARD_VECTOR_SIZE), no_match_sel(STANDARD_VECTOR_SIZE);
	while (this->count > 0) {
		// resolve the predicates for the current set of pointers
		idx_t match_count = ResolvePredicates(keys, match_sel, no_match_sel: &no_match_sel);
		idx_t no_match_count = this->count - match_count;

		// mark each of the matches as found
		for (idx_t i = 0; i < match_count; i++) {
			// found a match for this index
			auto index = match_sel.get_index(idx: i);
			found_match[index] = true;
			result_sel.set_index(idx: result_count++, loc: index);
		}
		// continue searching for the ones where we did not find a match yet
		// (matched rows stop here: SINGLE join keeps only the first match)
		AdvancePointers(sel: no_match_sel, sel_count: no_match_count);
	}
	// reference the columns of the left side from the result
	D_ASSERT(input.ColumnCount() > 0);
	for (idx_t i = 0; i < input.ColumnCount(); i++) {
		result.data[i].Reference(other&: input.data[i]);
	}
	// now fetch the data from the RHS
	for (idx_t i = 0; i < ht.build_types.size(); i++) {
		auto &vector = result.data[input.ColumnCount() + i];
		// set NULL entries for every entry that was not found
		for (idx_t j = 0; j < input.size(); j++) {
			if (!found_match[j]) {
				FlatVector::SetNull(vector, idx: j, is_null: true);
			}
		}
		// for the remaining values we fetch the values
		GatherResult(result&: vector, result_vector: result_sel, sel_vector: result_sel, count: result_count, col_no: i + ht.condition_types.size());
	}
	result.SetCardinality(input.size());

	// like the SEMI, ANTI and MARK join types, the SINGLE join only ever does one pass over the HT per input chunk
	finished = true;
}
769 | |
void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result) {
	// FULL/RIGHT OUTER support: scan the HT starting from the current position and emit the build-side
	// rows that never found a match, with the probe-side columns set to constant NULL.
	// Emits at most STANDARD_VECTOR_SIZE rows per call; "state" remembers where to resume.
	auto key_locations = FlatVector::GetData<data_ptr_t>(vector&: addresses);
	idx_t found_entries = 0;

	auto &iterator = state.iterator;
	if (iterator.Done()) {
		// the entire HT has already been scanned
		return;
	}

	// NOTE(review): assumes GetRowLocations() remains valid/updated as iterator.Next() advances — confirm
	const auto row_locations = iterator.GetRowLocations();
	do {
		const auto count = iterator.GetCurrentChunkCount();
		// resume mid-chunk if the previous call filled up while inside this chunk
		for (idx_t i = state.offset_in_chunk; i < count; i++) {
			// the "found match" marker is stored immediately after the tuple data
			auto found_match = Load<bool>(ptr: row_locations[i] + tuple_size);
			if (!found_match) {
				key_locations[found_entries++] = row_locations[i];
				if (found_entries == STANDARD_VECTOR_SIZE) {
					// output is full: remember where to pick up within this chunk
					state.offset_in_chunk = i + 1;
					break;
				}
			}
		}
		if (found_entries == STANDARD_VECTOR_SIZE) {
			// do not advance the iterator; the next call continues in the current chunk
			break;
		}
		state.offset_in_chunk = 0;
	} while (iterator.Next());

	// now gather from the found rows
	if (found_entries == 0) {
		return;
	}
	result.SetCardinality(found_entries);
	idx_t left_column_count = result.ColumnCount() - build_types.size();
	const auto &sel_vector = *FlatVector::IncrementalSelectionVector();
	// set the left side as a constant NULL
	for (idx_t i = 0; i < left_column_count; i++) {
		Vector &vec = result.data[i];
		vec.SetVectorType(VectorType::CONSTANT_VECTOR);
		ConstantVector::SetNull(vector&: vec, is_null: true);
	}

	// gather the values from the RHS
	for (idx_t i = 0; i < build_types.size(); i++) {
		auto &vector = result.data[left_column_count + i];
		D_ASSERT(vector.GetType() == build_types[i]);
		// build columns are stored after the key/condition columns
		const auto col_no = condition_types.size() + i;
		data_collection->Gather(row_locations&: addresses, sel: sel_vector, scan_count: found_entries, column_id: col_no, result&: vector, target_sel: sel_vector);
	}
}
821 | |
822 | idx_t JoinHashTable::FillWithHTOffsets(JoinHTScanState &state, Vector &addresses) { |
823 | // iterate over HT |
824 | auto key_locations = FlatVector::GetData<data_ptr_t>(vector&: addresses); |
825 | idx_t key_count = 0; |
826 | |
827 | auto &iterator = state.iterator; |
828 | const auto row_locations = iterator.GetRowLocations(); |
829 | do { |
830 | const auto count = iterator.GetCurrentChunkCount(); |
831 | for (idx_t i = 0; i < count; i++) { |
832 | key_locations[key_count + i] = row_locations[i]; |
833 | } |
834 | key_count += count; |
835 | } while (iterator.Next()); |
836 | |
837 | return key_count; |
838 | } |
839 | |
840 | bool JoinHashTable::RequiresExternalJoin(ClientConfig &config, vector<unique_ptr<JoinHashTable>> &local_hts) { |
841 | total_count = 0; |
842 | idx_t data_size = 0; |
843 | for (auto &ht : local_hts) { |
844 | auto &local_sink_collection = ht->GetSinkCollection(); |
845 | total_count += local_sink_collection.Count(); |
846 | data_size += local_sink_collection.SizeInBytes(); |
847 | } |
848 | |
849 | if (total_count == 0) { |
850 | return false; |
851 | } |
852 | |
853 | if (config.force_external) { |
854 | // Do ~3 rounds if forcing external join to test all code paths |
855 | auto data_size_per_round = (data_size + 2) / 3; |
856 | auto count_per_round = (total_count + 2) / 3; |
857 | max_ht_size = data_size_per_round + PointerTableSize(count: count_per_round); |
858 | external = true; |
859 | } else { |
860 | auto ht_size = data_size + PointerTableSize(count: total_count); |
861 | external = ht_size > max_ht_size; |
862 | } |
863 | return external; |
864 | } |
865 | |
866 | void JoinHashTable::Unpartition() { |
867 | for (auto &partition : sink_collection->GetPartitions()) { |
868 | data_collection->Combine(other&: *partition); |
869 | } |
870 | } |
871 | |
bool JoinHashTable::RequiresPartitioning(ClientConfig &config, vector<unique_ptr<JoinHashTable>> &local_hts) {
	// Decide whether the external join's sink data must be repartitioned with more radix bits:
	// if the largest current partition (plus its pointer table) does not fit in max_ht_size,
	// increase radix_bits and replace the sink collection. Returns true if repartitioning is needed.
	D_ASSERT(total_count != 0);
	D_ASSERT(external);

	idx_t num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits);
	vector<idx_t> partition_counts(num_partitions, 0);
	vector<idx_t> partition_sizes(num_partitions, 0);
	// aggregate per-partition tuple counts and byte sizes over all thread-local HTs
	for (auto &ht : local_hts) {
		const auto &local_partitions = ht->GetSinkCollection().GetPartitions();
		for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) {
			auto &local_partition = local_partitions[partition_idx];
			partition_counts[partition_idx] += local_partition->Count();
			partition_sizes[partition_idx] += local_partition->SizeInBytes();
		}
	}

	// Figure out if we can fit all single partitions in memory
	idx_t max_partition_idx = 0;
	idx_t max_partition_size = 0;
	for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) {
		const auto &partition_count = partition_counts[partition_idx];
		const auto &partition_size = partition_sizes[partition_idx];
		// a partition's in-memory footprint is its data plus the pointer table for its tuples
		auto partition_ht_size = partition_size + PointerTableSize(count: partition_count);
		if (partition_ht_size > max_partition_size) {
			max_partition_size = partition_ht_size;
			max_partition_idx = partition_idx;
		}
	}

	if (config.force_external || max_partition_size > max_ht_size) {
		const auto partition_count = partition_counts[max_partition_idx];
		const auto partition_size = partition_sizes[max_partition_idx];

		const auto max_added_bits = 8 - radix_bits;
		idx_t added_bits;
		// NOTE(review): if radix_bits were already 8, max_added_bits is 0, the loop body never
		// runs, and added_bits stays 1 — pushing radix_bits past 8. Presumably unreachable; verify.
		for (added_bits = 1; added_bits < max_added_bits; added_bits++) {
			// each added bit doubles the number of partitions, halving the estimated size
			double partition_multiplier = RadixPartitioning::NumberOfPartitions(radix_bits: added_bits);

			auto new_estimated_count = double(partition_count) / partition_multiplier;
			auto new_estimated_size = double(partition_size) / partition_multiplier;
			auto new_estimated_ht_size = new_estimated_size + PointerTableSize(count: new_estimated_count);

			if (new_estimated_ht_size <= double(max_ht_size) / 4) {
				// Aim for an estimated partition size of max_ht_size / 4
				break;
			}
		}
		radix_bits += added_bits;
		// recreate the sink collection with the new radix bit count
		// (last layout column is presumably the hash column — verify against the build path)
		sink_collection =
		    make_uniq<RadixPartitionedTupleData>(args&: buffer_manager, args&: layout, args&: radix_bits, args: layout.ColumnCount() - 1);
		return true;
	} else {
		return false;
	}
}
927 | |
928 | void JoinHashTable::Partition(JoinHashTable &global_ht) { |
929 | auto new_sink_collection = |
930 | make_uniq<RadixPartitionedTupleData>(args&: buffer_manager, args&: layout, args&: global_ht.radix_bits, args: layout.ColumnCount() - 1); |
931 | sink_collection->Repartition(new_partitioned_data&: *new_sink_collection); |
932 | sink_collection = std::move(new_sink_collection); |
933 | global_ht.Merge(other&: *this); |
934 | } |
935 | |
936 | void JoinHashTable::Reset() { |
937 | data_collection->Reset(); |
938 | finalized = false; |
939 | } |
940 | |
bool JoinHashTable::PrepareExternalFinalize() {
	// Prepare the next round of the external hash join: greedily pick the next run of
	// partitions whose combined HT size fits in max_ht_size (always taking at least one)
	// and move them into the main data collection. Returns false once all partitions are done.
	if (finalized) {
		// the previous round's HT is still materialized; clear it first
		Reset();
	}

	const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits);
	if (partition_end == num_partitions) {
		// every partition has been processed
		return false;
	}

	// Start where we left off
	auto &partitions = sink_collection->GetPartitions();
	partition_start = partition_end;

	// Determine how many partitions we can do next (at least one)
	idx_t count = 0;
	idx_t data_size = 0;
	idx_t partition_idx;
	for (partition_idx = partition_start; partition_idx < num_partitions; partition_idx++) {
		auto incl_count = count + partitions[partition_idx]->Count();
		auto incl_data_size = data_size + partitions[partition_idx]->SizeInBytes();
		auto incl_ht_size = incl_data_size + PointerTableSize(count: incl_count);
		if (count > 0 && incl_ht_size > max_ht_size) {
			// including this partition would exceed the budget (but we always take at least one)
			break;
		}
		count = incl_count;
		data_size = incl_data_size;
	}
	// this round covers partitions [partition_start, partition_end)
	partition_end = partition_idx;

	// Move the partitions to the main data collection
	for (partition_idx = partition_start; partition_idx < partition_end; partition_idx++) {
		data_collection->Combine(other&: *partitions[partition_idx]);
	}
	D_ASSERT(Count() == count);

	return true;
}
979 | |
980 | static void CreateSpillChunk(DataChunk &spill_chunk, DataChunk &keys, DataChunk &payload, Vector &hashes) { |
981 | spill_chunk.Reset(); |
982 | idx_t spill_col_idx = 0; |
983 | for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { |
984 | spill_chunk.data[col_idx].Reference(other&: keys.data[col_idx]); |
985 | } |
986 | spill_col_idx += keys.ColumnCount(); |
987 | for (idx_t col_idx = 0; col_idx < payload.data.size(); col_idx++) { |
988 | spill_chunk.data[spill_col_idx + col_idx].Reference(other&: payload.data[col_idx]); |
989 | } |
990 | spill_col_idx += payload.ColumnCount(); |
991 | spill_chunk.data[spill_col_idx].Reference(other&: hashes); |
992 | } |
993 | |
unique_ptr<ScanStructure> JoinHashTable::ProbeAndSpill(DataChunk &keys, DataChunk &payload, ProbeSpill &probe_spill,
                                                       ProbeSpillLocalAppendState &spill_state,
                                                       DataChunk &spill_chunk) {
	// External join probe: rows whose hash falls in a partition that is currently loaded
	// (below partition_end) are probed immediately; the remaining rows are appended to the
	// probe spill for a later round. Returns the scan structure for the probeable rows.
	// hash all the keys
	Vector hashes(LogicalType::HASH);
	Hash(keys, sel: *FlatVector::IncrementalSelectionVector(), count: keys.size(), hashes);

	// find out which keys we can match with the current pinned partitions
	SelectionVector true_sel;
	SelectionVector false_sel;
	true_sel.Initialize();
	false_sel.Initialize();
	auto true_count = RadixPartitioning::Select(hashes, sel: FlatVector::IncrementalSelectionVector(), count: keys.size(),
	                                            radix_bits, cutoff: partition_end, true_sel: &true_sel, false_sel: &false_sel);
	auto false_count = keys.size() - true_count;

	// build the spill chunk [keys..., payload..., hash] BEFORE slicing keys/payload below
	CreateSpillChunk(spill_chunk, keys, payload, hashes);

	// can't probe these values right now, append to spill
	spill_chunk.Slice(sel_vector: false_sel, count: false_count);
	spill_chunk.Verify();
	probe_spill.Append(chunk&: spill_chunk, local_state&: spill_state);

	// slice the stuff we CAN probe right now
	hashes.Slice(sel: true_sel, count: true_count);
	keys.Slice(sel_vector: true_sel, count: true_count);
	payload.Slice(sel_vector: true_sel, count: true_count);

	const SelectionVector *current_sel;
	auto ss = InitializeScanStructure(keys, current_sel);
	if (ss->count == 0) {
		// nothing to probe in this round
		return ss;
	}

	// now initialize the pointers of the scan structure based on the hashes
	ApplyBitmask(hashes, sel: *current_sel, count: ss->count, pointers&: ss->pointers);

	// create the selection vector linking to only non-empty entries
	ss->InitializeSelectionVector(current_sel);

	return ss;
}
1036 | |
1037 | ProbeSpill::ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector<LogicalType> &probe_types) |
1038 | : ht(ht), context(context), probe_types(probe_types) { |
1039 | auto remaining_count = ht.GetSinkCollection().Count(); |
1040 | auto remaining_data_size = ht.GetSinkCollection().SizeInBytes(); |
1041 | auto remaining_ht_size = remaining_data_size + ht.PointerTableSize(count: remaining_count); |
1042 | if (remaining_ht_size <= ht.max_ht_size) { |
1043 | // No need to partition as we will only have one more probe round |
1044 | partitioned = false; |
1045 | } else { |
1046 | // More than one probe round to go, so we need to partition |
1047 | partitioned = true; |
1048 | global_partitions = |
1049 | make_uniq<RadixPartitionedColumnData>(args&: context, args: probe_types, args&: ht.radix_bits, args: probe_types.size() - 1); |
1050 | } |
1051 | column_ids.reserve(n: probe_types.size()); |
1052 | for (column_t column_id = 0; column_id < probe_types.size(); column_id++) { |
1053 | column_ids.emplace_back(args&: column_id); |
1054 | } |
1055 | } |
1056 | |
1057 | ProbeSpillLocalState ProbeSpill::RegisterThread() { |
1058 | ProbeSpillLocalAppendState result; |
1059 | lock_guard<mutex> guard(lock); |
1060 | if (partitioned) { |
1061 | local_partitions.emplace_back(args: global_partitions->CreateShared()); |
1062 | local_partition_append_states.emplace_back(args: make_uniq<PartitionedColumnDataAppendState>()); |
1063 | local_partitions.back()->InitializeAppendState(state&: *local_partition_append_states.back()); |
1064 | |
1065 | result.local_partition = local_partitions.back().get(); |
1066 | result.local_partition_append_state = local_partition_append_states.back().get(); |
1067 | } else { |
1068 | local_spill_collections.emplace_back( |
1069 | args: make_uniq<ColumnDataCollection>(args&: BufferManager::GetBufferManager(context), args: probe_types)); |
1070 | local_spill_append_states.emplace_back(args: make_uniq<ColumnDataAppendState>()); |
1071 | local_spill_collections.back()->InitializeAppend(state&: *local_spill_append_states.back()); |
1072 | |
1073 | result.local_spill_collection = local_spill_collections.back().get(); |
1074 | result.local_spill_append_state = local_spill_append_states.back().get(); |
1075 | } |
1076 | return result; |
1077 | } |
1078 | |
1079 | void ProbeSpill::Append(DataChunk &chunk, ProbeSpillLocalAppendState &local_state) { |
1080 | if (partitioned) { |
1081 | local_state.local_partition->Append(state&: *local_state.local_partition_append_state, input&: chunk); |
1082 | } else { |
1083 | local_state.local_spill_collection->Append(state&: *local_state.local_spill_append_state, new_chunk&: chunk); |
1084 | } |
1085 | } |
1086 | |
1087 | void ProbeSpill::Finalize() { |
1088 | if (partitioned) { |
1089 | D_ASSERT(local_partitions.size() == local_partition_append_states.size()); |
1090 | for (idx_t i = 0; i < local_partition_append_states.size(); i++) { |
1091 | local_partitions[i]->FlushAppendState(state&: *local_partition_append_states[i]); |
1092 | } |
1093 | for (auto &local_partition : local_partitions) { |
1094 | global_partitions->Combine(other&: *local_partition); |
1095 | } |
1096 | local_partitions.clear(); |
1097 | local_partition_append_states.clear(); |
1098 | } else { |
1099 | if (local_spill_collections.empty()) { |
1100 | global_spill_collection = |
1101 | make_uniq<ColumnDataCollection>(args&: BufferManager::GetBufferManager(context), args: probe_types); |
1102 | } else { |
1103 | global_spill_collection = std::move(local_spill_collections[0]); |
1104 | for (idx_t i = 1; i < local_spill_collections.size(); i++) { |
1105 | global_spill_collection->Combine(other&: *local_spill_collections[i]); |
1106 | } |
1107 | } |
1108 | local_spill_collections.clear(); |
1109 | local_spill_append_states.clear(); |
1110 | } |
1111 | } |
1112 | |
void ProbeSpill::PrepareNextProbe() {
	// Assemble the probe data for the next external join round into global_spill_collection
	// and initialize a consumer to scan it. In the unpartitioned case the collection was
	// already populated by Finalize().
	if (partitioned) {
		auto &partitions = global_partitions->GetPartitions();
		if (partitions.empty() || ht.partition_start == partitions.size()) {
			// Can't probe, just make an empty one
			global_spill_collection =
			    make_uniq<ColumnDataCollection>(args&: BufferManager::GetBufferManager(context), args: probe_types);
		} else {
			// Move the partitions of this round [ht.partition_start, ht.partition_end) into the
			// global spill collection (matching the build partitions loaded into the HT)
			global_spill_collection = std::move(partitions[ht.partition_start]);
			for (idx_t i = ht.partition_start + 1; i < ht.partition_end; i++) {
				auto &partition = partitions[i];
				if (global_spill_collection->Count() == 0) {
					// everything so far was empty: just take this partition instead of combining
					global_spill_collection = std::move(partition);
				} else {
					global_spill_collection->Combine(other&: *partition);
				}
			}
		}
	}
	consumer = make_uniq<ColumnDataConsumer>(args&: *global_spill_collection, args&: column_ids);
	consumer->InitializeScan();
}
1136 | |
1137 | } // namespace duckdb |
1138 | |