1#include "duckdb/execution/operator/join/physical_hash_join.hpp"
2
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4#include "duckdb/execution/expression_executor.hpp"
5#include "duckdb/function/aggregate/distributive_functions.hpp"
6#include "duckdb/function/function_binder.hpp"
7#include "duckdb/main/client_context.hpp"
8#include "duckdb/main/query_profiler.hpp"
9#include "duckdb/parallel/base_pipeline_event.hpp"
10#include "duckdb/parallel/pipeline.hpp"
11#include "duckdb/parallel/thread_context.hpp"
12#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
13#include "duckdb/planner/expression/bound_reference_expression.hpp"
14#include "duckdb/storage/buffer_manager.hpp"
15#include "duckdb/storage/storage_manager.hpp"
16
17namespace duckdb {
18
19PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr<PhysicalOperator> left,
20 unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type,
21 const vector<idx_t> &left_projection_map,
22 const vector<idx_t> &right_projection_map_p, vector<LogicalType> delim_types,
23 idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_stats)
24 : PhysicalComparisonJoin(op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality),
25 right_projection_map(right_projection_map_p), delim_types(std::move(delim_types)),
26 perfect_join_statistics(std::move(perfect_join_stats)) {
27
28 children.push_back(x: std::move(left));
29 children.push_back(x: std::move(right));
30
31 D_ASSERT(left_projection_map.empty());
32 for (auto &condition : conditions) {
33 condition_types.push_back(x: condition.left->return_type);
34 }
35
36 // for ANTI, SEMI and MARK join, we only need to store the keys, so for these the build types are empty
37 if (join_type != JoinType::ANTI && join_type != JoinType::SEMI && join_type != JoinType::MARK) {
38 build_types = LogicalOperator::MapTypes(types: children[1]->GetTypes(), projection_map: right_projection_map);
39 }
40}
41
42PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr<PhysicalOperator> left,
43 unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type,
44 idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_state)
45 : PhysicalHashJoin(op, std::move(left), std::move(right), std::move(cond), join_type, {}, {}, {},
46 estimated_cardinality, std::move(perfect_join_state)) {
47}
48
//===--------------------------------------------------------------------===//
// Sink
//===--------------------------------------------------------------------===//
52class HashJoinGlobalSinkState : public GlobalSinkState {
53public:
54 HashJoinGlobalSinkState(const PhysicalHashJoin &op, ClientContext &context_p)
55 : context(context_p), finalized(false), scanned_data(false) {
56 hash_table = op.InitializeHashTable(context);
57
58 // for perfect hash join
59 perfect_join_executor = make_uniq<PerfectHashJoinExecutor>(args: op, args&: *hash_table, args: op.perfect_join_statistics);
60 // for external hash join
61 external = ClientConfig::GetConfig(context).force_external;
62 // Set probe types
63 const auto &payload_types = op.children[0]->types;
64 probe_types.insert(position: probe_types.end(), first: op.condition_types.begin(), last: op.condition_types.end());
65 probe_types.insert(position: probe_types.end(), first: payload_types.begin(), last: payload_types.end());
66 probe_types.emplace_back(args: LogicalType::HASH);
67 }
68
69 void ScheduleFinalize(Pipeline &pipeline, Event &event);
70 void InitializeProbeSpill();
71
72public:
73 ClientContext &context;
74 //! Global HT used by the join
75 unique_ptr<JoinHashTable> hash_table;
76 //! The perfect hash join executor (if any)
77 unique_ptr<PerfectHashJoinExecutor> perfect_join_executor;
78 //! Whether or not the hash table has been finalized
79 bool finalized = false;
80
81 //! Whether we are doing an external join
82 bool external;
83
84 //! Hash tables built by each thread
85 mutex lock;
86 vector<unique_ptr<JoinHashTable>> local_hash_tables;
87
88 //! Excess probe data gathered during Sink
89 vector<LogicalType> probe_types;
90 unique_ptr<JoinHashTable::ProbeSpill> probe_spill;
91
92 //! Whether or not we have started scanning data using GetData
93 atomic<bool> scanned_data;
94};
95
96class HashJoinLocalSinkState : public LocalSinkState {
97public:
98 HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context) : build_executor(context) {
99 auto &allocator = Allocator::Get(context);
100 if (!op.right_projection_map.empty()) {
101 build_chunk.Initialize(allocator, types: op.build_types);
102 }
103 for (auto &cond : op.conditions) {
104 build_executor.AddExpression(expr: *cond.right);
105 }
106 join_keys.Initialize(allocator, types: op.condition_types);
107
108 hash_table = op.InitializeHashTable(context);
109
110 hash_table->GetSinkCollection().InitializeAppendState(state&: append_state);
111 }
112
113public:
114 PartitionedTupleDataAppendState append_state;
115
116 DataChunk build_chunk;
117 DataChunk join_keys;
118 ExpressionExecutor build_executor;
119
120 //! Thread-local HT
121 unique_ptr<JoinHashTable> hash_table;
122};
123
124unique_ptr<JoinHashTable> PhysicalHashJoin::InitializeHashTable(ClientContext &context) const {
125 auto result =
126 make_uniq<JoinHashTable>(args&: BufferManager::GetBufferManager(context), args: conditions, args: build_types, args: join_type);
127 result->max_ht_size = double(BufferManager::GetBufferManager(context).GetMaxMemory()) * 0.6;
128 if (!delim_types.empty() && join_type == JoinType::MARK) {
129 // correlated MARK join
130 if (delim_types.size() + 1 == conditions.size()) {
131 // the correlated MARK join has one more condition than the amount of correlated columns
132 // this is the case in a correlated ANY() expression
133 // in this case we need to keep track of additional entries, namely:
134 // - (1) the total amount of elements per group
135 // - (2) the amount of non-null elements per group
136 // we need these to correctly deal with the cases of either:
137 // - (1) the group being empty [in which case the result is always false, even if the comparison is NULL]
138 // - (2) the group containing a NULL value [in which case FALSE becomes NULL]
139 auto &info = result->correlated_mark_join_info;
140
141 vector<LogicalType> payload_types;
142 vector<BoundAggregateExpression *> correlated_aggregates;
143 unique_ptr<BoundAggregateExpression> aggr;
144
145 // jury-rigging the GroupedAggregateHashTable
146 // we need a count_star and a count to get counts with and without NULLs
147
148 FunctionBinder function_binder(context);
149 aggr = function_binder.BindAggregateFunction(bound_function: CountStarFun::GetFunction(), children: {}, filter: nullptr,
150 aggr_type: AggregateType::NON_DISTINCT);
151 correlated_aggregates.push_back(x: &*aggr);
152 payload_types.push_back(x: aggr->return_type);
153 info.correlated_aggregates.push_back(x: std::move(aggr));
154
155 auto count_fun = CountFun::GetFunction();
156 vector<unique_ptr<Expression>> children;
157 // this is a dummy but we need it to make the hash table understand whats going on
158 children.push_back(x: make_uniq_base<Expression, BoundReferenceExpression>(args&: count_fun.return_type, args: 0));
159 aggr = function_binder.BindAggregateFunction(bound_function: count_fun, children: std::move(children), filter: nullptr,
160 aggr_type: AggregateType::NON_DISTINCT);
161 correlated_aggregates.push_back(x: &*aggr);
162 payload_types.push_back(x: aggr->return_type);
163 info.correlated_aggregates.push_back(x: std::move(aggr));
164
165 auto &allocator = Allocator::Get(context);
166 info.correlated_counts = make_uniq<GroupedAggregateHashTable>(args&: context, args&: allocator, args: delim_types,
167 args&: payload_types, args&: correlated_aggregates);
168 info.correlated_types = delim_types;
169 info.group_chunk.Initialize(allocator, types: delim_types);
170 info.result_chunk.Initialize(allocator, types: payload_types);
171 }
172 }
173 return result;
174}
175
176unique_ptr<GlobalSinkState> PhysicalHashJoin::GetGlobalSinkState(ClientContext &context) const {
177 return make_uniq<HashJoinGlobalSinkState>(args: *this, args&: context);
178}
179
180unique_ptr<LocalSinkState> PhysicalHashJoin::GetLocalSinkState(ExecutionContext &context) const {
181 return make_uniq<HashJoinLocalSinkState>(args: *this, args&: context.client);
182}
183
184SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const {
185 auto &lstate = input.local_state.Cast<HashJoinLocalSinkState>();
186
187 // resolve the join keys for the right chunk
188 lstate.join_keys.Reset();
189 lstate.build_executor.Execute(input&: chunk, result&: lstate.join_keys);
190
191 // build the HT
192 auto &ht = *lstate.hash_table;
193 if (!right_projection_map.empty()) {
194 // there is a projection map: fill the build chunk with the projected columns
195 lstate.build_chunk.Reset();
196 lstate.build_chunk.SetCardinality(chunk);
197 for (idx_t i = 0; i < right_projection_map.size(); i++) {
198 lstate.build_chunk.data[i].Reference(other&: chunk.data[right_projection_map[i]]);
199 }
200 ht.Build(append_state&: lstate.append_state, keys&: lstate.join_keys, input&: lstate.build_chunk);
201 } else if (!build_types.empty()) {
202 // there is not a projected map: place the entire right chunk in the HT
203 ht.Build(append_state&: lstate.append_state, keys&: lstate.join_keys, input&: chunk);
204 } else {
205 // there are only keys: place an empty chunk in the payload
206 lstate.build_chunk.SetCardinality(chunk.size());
207 ht.Build(append_state&: lstate.append_state, keys&: lstate.join_keys, input&: lstate.build_chunk);
208 }
209
210 return SinkResultType::NEED_MORE_INPUT;
211}
212
213void PhysicalHashJoin::Combine(ExecutionContext &context, GlobalSinkState &gstate_p, LocalSinkState &lstate_p) const {
214 auto &gstate = gstate_p.Cast<HashJoinGlobalSinkState>();
215 auto &lstate = lstate_p.Cast<HashJoinLocalSinkState>();
216 if (lstate.hash_table) {
217 lstate.hash_table->GetSinkCollection().FlushAppendState(state&: lstate.append_state);
218 lock_guard<mutex> local_ht_lock(gstate.lock);
219 gstate.local_hash_tables.push_back(x: std::move(lstate.hash_table));
220 }
221 auto &client_profiler = QueryProfiler::Get(context&: context.client);
222 context.thread.profiler.Flush(phys_op: *this, expression_executor&: lstate.build_executor, name: "build_executor", id: 1);
223 client_profiler.Flush(profiler&: context.thread.profiler);
224}
225
//===--------------------------------------------------------------------===//
// Finalize
//===--------------------------------------------------------------------===//
229class HashJoinFinalizeTask : public ExecutorTask {
230public:
231 HashJoinFinalizeTask(shared_ptr<Event> event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p,
232 idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p)
233 : ExecutorTask(context), event(std::move(event_p)), sink(sink_p), chunk_idx_from(chunk_idx_from_p),
234 chunk_idx_to(chunk_idx_to_p), parallel(parallel_p) {
235 }
236
237 TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
238 sink.hash_table->Finalize(chunk_idx_from, chunk_idx_to, parallel);
239 event->FinishTask();
240 return TaskExecutionResult::TASK_FINISHED;
241 }
242
243private:
244 shared_ptr<Event> event;
245 HashJoinGlobalSinkState &sink;
246 idx_t chunk_idx_from;
247 idx_t chunk_idx_to;
248 bool parallel;
249};
250
251class HashJoinFinalizeEvent : public BasePipelineEvent {
252public:
253 HashJoinFinalizeEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink)
254 : BasePipelineEvent(pipeline_p), sink(sink) {
255 }
256
257 HashJoinGlobalSinkState &sink;
258
259public:
260 void Schedule() override {
261 auto &context = pipeline->GetClientContext();
262
263 vector<shared_ptr<Task>> finalize_tasks;
264 auto &ht = *sink.hash_table;
265 const auto chunk_count = ht.GetDataCollection().ChunkCount();
266 const idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads();
267 if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && !context.config.verify_parallelism)) {
268 // Single-threaded finalize
269 finalize_tasks.push_back(
270 x: make_uniq<HashJoinFinalizeTask>(args: shared_from_this(), args&: context, args&: sink, args: 0, args: chunk_count, args: false));
271 } else {
272 // Parallel finalize
273 auto chunks_per_thread = MaxValue<idx_t>(a: (chunk_count + num_threads - 1) / num_threads, b: 1);
274
275 idx_t chunk_idx = 0;
276 for (idx_t thread_idx = 0; thread_idx < num_threads; thread_idx++) {
277 auto chunk_idx_from = chunk_idx;
278 auto chunk_idx_to = MinValue<idx_t>(a: chunk_idx_from + chunks_per_thread, b: chunk_count);
279 finalize_tasks.push_back(x: make_uniq<HashJoinFinalizeTask>(args: shared_from_this(), args&: context, args&: sink,
280 args&: chunk_idx_from, args&: chunk_idx_to, args: true));
281 chunk_idx = chunk_idx_to;
282 if (chunk_idx == chunk_count) {
283 break;
284 }
285 }
286 }
287 SetTasks(std::move(finalize_tasks));
288 }
289
290 void FinishEvent() override {
291 sink.hash_table->GetDataCollection().VerifyEverythingPinned();
292 sink.hash_table->finalized = true;
293 }
294
295 static constexpr const idx_t PARALLEL_CONSTRUCT_THRESHOLD = 1048576;
296};
297
298void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) {
299 if (hash_table->Count() == 0) {
300 hash_table->finalized = true;
301 return;
302 }
303 hash_table->InitializePointerTable();
304 auto new_event = make_shared<HashJoinFinalizeEvent>(args&: pipeline, args&: *this);
305 event.InsertEvent(replacement_event: std::move(new_event));
306}
307
308void HashJoinGlobalSinkState::InitializeProbeSpill() {
309 lock_guard<mutex> guard(lock);
310 if (!probe_spill) {
311 probe_spill = make_uniq<JoinHashTable::ProbeSpill>(args&: *hash_table, args&: context, args&: probe_types);
312 }
313}
314
315class HashJoinPartitionTask : public ExecutorTask {
316public:
317 HashJoinPartitionTask(shared_ptr<Event> event_p, ClientContext &context, JoinHashTable &global_ht,
318 JoinHashTable &local_ht)
319 : ExecutorTask(context), event(std::move(event_p)), global_ht(global_ht), local_ht(local_ht) {
320 }
321
322 TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
323 local_ht.Partition(global_ht);
324 event->FinishTask();
325 return TaskExecutionResult::TASK_FINISHED;
326 }
327
328private:
329 shared_ptr<Event> event;
330
331 JoinHashTable &global_ht;
332 JoinHashTable &local_ht;
333};
334
335class HashJoinPartitionEvent : public BasePipelineEvent {
336public:
337 HashJoinPartitionEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink,
338 vector<unique_ptr<JoinHashTable>> &local_hts)
339 : BasePipelineEvent(pipeline_p), sink(sink), local_hts(local_hts) {
340 }
341
342 HashJoinGlobalSinkState &sink;
343 vector<unique_ptr<JoinHashTable>> &local_hts;
344
345public:
346 void Schedule() override {
347 auto &context = pipeline->GetClientContext();
348 vector<shared_ptr<Task>> partition_tasks;
349 partition_tasks.reserve(n: local_hts.size());
350 for (auto &local_ht : local_hts) {
351 partition_tasks.push_back(
352 x: make_uniq<HashJoinPartitionTask>(args: shared_from_this(), args&: context, args&: *sink.hash_table, args&: *local_ht));
353 }
354 SetTasks(std::move(partition_tasks));
355 }
356
357 void FinishEvent() override {
358 local_hts.clear();
359 sink.hash_table->PrepareExternalFinalize();
360 sink.ScheduleFinalize(pipeline&: *pipeline, event&: *this);
361 }
362};
363
364SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context,
365 GlobalSinkState &gstate) const {
366 auto &sink = gstate.Cast<HashJoinGlobalSinkState>();
367 auto &ht = *sink.hash_table;
368
369 sink.external = ht.RequiresExternalJoin(config&: context.config, local_hts&: sink.local_hash_tables);
370 if (sink.external) {
371 sink.perfect_join_executor.reset();
372 if (ht.RequiresPartitioning(config&: context.config, local_hts&: sink.local_hash_tables)) {
373 auto new_event = make_shared<HashJoinPartitionEvent>(args&: pipeline, args&: sink, args&: sink.local_hash_tables);
374 event.InsertEvent(replacement_event: std::move(new_event));
375 } else {
376 for (auto &local_ht : sink.local_hash_tables) {
377 ht.Merge(other&: *local_ht);
378 }
379 sink.local_hash_tables.clear();
380 sink.hash_table->PrepareExternalFinalize();
381 sink.ScheduleFinalize(pipeline, event);
382 }
383 sink.finalized = true;
384 return SinkFinalizeType::READY;
385 } else {
386 for (auto &local_ht : sink.local_hash_tables) {
387 ht.Merge(other&: *local_ht);
388 }
389 sink.local_hash_tables.clear();
390 ht.Unpartition();
391 }
392
393 // check for possible perfect hash table
394 auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin();
395 if (use_perfect_hash) {
396 D_ASSERT(ht.equality_types.size() == 1);
397 auto key_type = ht.equality_types[0];
398 use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(type&: key_type);
399 }
400 // In case of a large build side or duplicates, use regular hash join
401 if (!use_perfect_hash) {
402 sink.perfect_join_executor.reset();
403 sink.ScheduleFinalize(pipeline, event);
404 }
405 sink.finalized = true;
406 if (ht.Count() == 0 && EmptyResultIfRHSIsEmpty()) {
407 return SinkFinalizeType::NO_OUTPUT_POSSIBLE;
408 }
409 return SinkFinalizeType::READY;
410}
411
//===--------------------------------------------------------------------===//
// Operator
//===--------------------------------------------------------------------===//
415class HashJoinOperatorState : public CachingOperatorState {
416public:
417 explicit HashJoinOperatorState(ClientContext &context) : probe_executor(context), initialized(false) {
418 }
419
420 DataChunk join_keys;
421 ExpressionExecutor probe_executor;
422 unique_ptr<JoinHashTable::ScanStructure> scan_structure;
423 unique_ptr<OperatorState> perfect_hash_join_state;
424
425 bool initialized;
426 JoinHashTable::ProbeSpillLocalAppendState spill_state;
427 //! Chunk to sink data into for external join
428 DataChunk spill_chunk;
429
430public:
431 void Finalize(const PhysicalOperator &op, ExecutionContext &context) override {
432 context.thread.profiler.Flush(phys_op: op, expression_executor&: probe_executor, name: "probe_executor", id: 0);
433 }
434};
435
436unique_ptr<OperatorState> PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const {
437 auto &allocator = Allocator::Get(context&: context.client);
438 auto &sink = sink_state->Cast<HashJoinGlobalSinkState>();
439 auto state = make_uniq<HashJoinOperatorState>(args&: context.client);
440 if (sink.perfect_join_executor) {
441 state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context);
442 } else {
443 state->join_keys.Initialize(allocator, types: condition_types);
444 for (auto &cond : conditions) {
445 state->probe_executor.AddExpression(expr: *cond.left);
446 }
447 }
448 if (sink.external) {
449 state->spill_chunk.Initialize(allocator, types: sink.probe_types);
450 sink.InitializeProbeSpill();
451 }
452
453 return std::move(state);
454}
455
456OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk,
457 GlobalOperatorState &gstate, OperatorState &state_p) const {
458 auto &state = state_p.Cast<HashJoinOperatorState>();
459 auto &sink = sink_state->Cast<HashJoinGlobalSinkState>();
460 D_ASSERT(sink.finalized);
461 D_ASSERT(!sink.scanned_data);
462
463 // some initialization for external hash join
464 if (sink.external && !state.initialized) {
465 if (!sink.probe_spill) {
466 sink.InitializeProbeSpill();
467 }
468 state.spill_state = sink.probe_spill->RegisterThread();
469 state.initialized = true;
470 }
471
472 if (sink.hash_table->Count() == 0 && EmptyResultIfRHSIsEmpty()) {
473 return OperatorResultType::FINISHED;
474 }
475
476 if (sink.perfect_join_executor) {
477 D_ASSERT(!sink.external);
478 return sink.perfect_join_executor->ProbePerfectHashTable(context, input, chunk, state&: *state.perfect_hash_join_state);
479 }
480
481 if (state.scan_structure) {
482 // still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe)
483 state.scan_structure->Next(keys&: state.join_keys, left&: input, result&: chunk);
484 if (chunk.size() > 0) {
485 return OperatorResultType::HAVE_MORE_OUTPUT;
486 }
487 state.scan_structure = nullptr;
488 return OperatorResultType::NEED_MORE_INPUT;
489 }
490
491 // probe the HT
492 if (sink.hash_table->Count() == 0) {
493 ConstructEmptyJoinResult(join_type: sink.hash_table->join_type, has_null: sink.hash_table->has_null, input, result&: chunk);
494 return OperatorResultType::NEED_MORE_INPUT;
495 }
496
497 // resolve the join keys for the left chunk
498 state.join_keys.Reset();
499 state.probe_executor.Execute(input, result&: state.join_keys);
500
501 // perform the actual probe
502 if (sink.external) {
503 state.scan_structure = sink.hash_table->ProbeAndSpill(keys&: state.join_keys, payload&: input, probe_spill&: *sink.probe_spill,
504 spill_state&: state.spill_state, spill_chunk&: state.spill_chunk);
505 } else {
506 state.scan_structure = sink.hash_table->Probe(keys&: state.join_keys);
507 }
508 state.scan_structure->Next(keys&: state.join_keys, left&: input, result&: chunk);
509 return OperatorResultType::HAVE_MORE_OUTPUT;
510}
511
//===--------------------------------------------------------------------===//
// Source
//===--------------------------------------------------------------------===//
//! Stages of the source side of the hash join
enum class HashJoinSourceStage : uint8_t { INIT, BUILD, PROBE, SCAN_HT, DONE };

class HashJoinLocalSourceState;
518
519class HashJoinGlobalSourceState : public GlobalSourceState {
520public:
521 HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context);
522
523 //! Initialize this source state using the info in the sink
524 void Initialize(HashJoinGlobalSinkState &sink);
525 //! Try to prepare the next stage
526 void TryPrepareNextStage(HashJoinGlobalSinkState &sink);
527 //! Prepare the next build/probe/scan_ht stage for external hash join (must hold lock)
528 void PrepareBuild(HashJoinGlobalSinkState &sink);
529 void PrepareProbe(HashJoinGlobalSinkState &sink);
530 void PrepareScanHT(HashJoinGlobalSinkState &sink);
531 //! Assigns a task to a local source state
532 bool AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate);
533
534 idx_t MaxThreads() override {
535 return probe_count / ((idx_t)STANDARD_VECTOR_SIZE * parallel_scan_chunk_count);
536 }
537
538public:
539 const PhysicalHashJoin &op;
540
541 //! For synchronizing the external hash join
542 atomic<HashJoinSourceStage> global_stage;
543 mutex lock;
544
545 //! For HT build synchronization
546 idx_t build_chunk_idx;
547 idx_t build_chunk_count;
548 idx_t build_chunk_done;
549 idx_t build_chunks_per_thread;
550
551 //! For probe synchronization
552 idx_t probe_chunk_count;
553 idx_t probe_chunk_done;
554
555 //! To determine the number of threads
556 idx_t probe_count;
557 idx_t parallel_scan_chunk_count;
558
559 //! For full/outer synchronization
560 idx_t full_outer_chunk_idx;
561 idx_t full_outer_chunk_count;
562 idx_t full_outer_chunk_done;
563 idx_t full_outer_chunks_per_thread;
564};
565
566class HashJoinLocalSourceState : public LocalSourceState {
567public:
568 HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator);
569
570 //! Do the work this thread has been assigned
571 void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk);
572 //! Whether this thread has finished the work it has been assigned
573 bool TaskFinished();
574 //! Build, probe and scan for external hash join
575 void ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate);
576 void ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk);
577 void ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk);
578
579public:
580 //! The stage that this thread was assigned work for
581 HashJoinSourceStage local_stage;
582 //! Vector with pointers here so we don't have to re-initialize
583 Vector addresses;
584
585 //! Chunks assigned to this thread for building the pointer table
586 idx_t build_chunk_idx_from;
587 idx_t build_chunk_idx_to;
588
589 //! Local scan state for probe spill
590 ColumnDataConsumerScanState probe_local_scan;
591 //! Chunks for holding the scanned probe collection
592 DataChunk probe_chunk;
593 DataChunk join_keys;
594 DataChunk payload;
595 //! Column indices to easily reference the join keys/payload columns in probe_chunk
596 vector<idx_t> join_key_indices;
597 vector<idx_t> payload_indices;
598 //! Scan structure for the external probe
599 unique_ptr<JoinHashTable::ScanStructure> scan_structure;
600 bool empty_ht_probe_in_progress;
601
602 //! Chunks assigned to this thread for a full/outer scan
603 idx_t full_outer_chunk_idx_from;
604 idx_t full_outer_chunk_idx_to;
605 unique_ptr<JoinHTScanState> full_outer_scan_state;
606};
607
608unique_ptr<GlobalSourceState> PhysicalHashJoin::GetGlobalSourceState(ClientContext &context) const {
609 return make_uniq<HashJoinGlobalSourceState>(args: *this, args&: context);
610}
611
612unique_ptr<LocalSourceState> PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context,
613 GlobalSourceState &gstate) const {
614 return make_uniq<HashJoinLocalSourceState>(args: *this, args&: Allocator::Get(context&: context.client));
615}
616
617HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context)
618 : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0),
619 probe_chunk_done(0), probe_count(op.children[0]->estimated_cardinality),
620 parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) {
621}
622
623void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) {
624 lock_guard<mutex> init_lock(lock);
625 if (global_stage != HashJoinSourceStage::INIT) {
626 // Another thread initialized
627 return;
628 }
629
630 // Finalize the probe spill
631 if (sink.probe_spill) {
632 sink.probe_spill->Finalize();
633 }
634
635 global_stage = HashJoinSourceStage::PROBE;
636 TryPrepareNextStage(sink);
637}
638
639void HashJoinGlobalSourceState::TryPrepareNextStage(HashJoinGlobalSinkState &sink) {
640 switch (global_stage.load()) {
641 case HashJoinSourceStage::BUILD:
642 if (build_chunk_done == build_chunk_count) {
643 sink.hash_table->GetDataCollection().VerifyEverythingPinned();
644 sink.hash_table->finalized = true;
645 PrepareProbe(sink);
646 }
647 break;
648 case HashJoinSourceStage::PROBE:
649 if (probe_chunk_done == probe_chunk_count) {
650 if (IsRightOuterJoin(type: op.join_type)) {
651 PrepareScanHT(sink);
652 } else {
653 PrepareBuild(sink);
654 }
655 }
656 break;
657 case HashJoinSourceStage::SCAN_HT:
658 if (full_outer_chunk_done == full_outer_chunk_count) {
659 PrepareBuild(sink);
660 }
661 break;
662 default:
663 break;
664 }
665}
666
667void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) {
668 D_ASSERT(global_stage != HashJoinSourceStage::BUILD);
669 auto &ht = *sink.hash_table;
670
671 // Try to put the next partitions in the block collection of the HT
672 if (!sink.external || !ht.PrepareExternalFinalize()) {
673 global_stage = HashJoinSourceStage::DONE;
674 return;
675 }
676
677 auto &data_collection = ht.GetDataCollection();
678 if (data_collection.Count() == 0 && op.EmptyResultIfRHSIsEmpty()) {
679 PrepareBuild(sink);
680 return;
681 }
682
683 build_chunk_idx = 0;
684 build_chunk_count = data_collection.ChunkCount();
685 build_chunk_done = 0;
686
687 auto num_threads = TaskScheduler::GetScheduler(context&: sink.context).NumberOfThreads();
688 build_chunks_per_thread = MaxValue<idx_t>(a: (build_chunk_count + num_threads - 1) / num_threads, b: 1);
689
690 ht.InitializePointerTable();
691
692 global_stage = HashJoinSourceStage::BUILD;
693}
694
695void HashJoinGlobalSourceState::PrepareProbe(HashJoinGlobalSinkState &sink) {
696 sink.probe_spill->PrepareNextProbe();
697 const auto &consumer = *sink.probe_spill->consumer;
698
699 probe_chunk_count = consumer.Count() == 0 ? 0 : consumer.ChunkCount();
700 probe_chunk_done = 0;
701
702 global_stage = HashJoinSourceStage::PROBE;
703 if (probe_chunk_count == 0) {
704 TryPrepareNextStage(sink);
705 return;
706 }
707}
708
709void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) {
710 D_ASSERT(global_stage != HashJoinSourceStage::SCAN_HT);
711 auto &ht = *sink.hash_table;
712
713 auto &data_collection = ht.GetDataCollection();
714 full_outer_chunk_idx = 0;
715 full_outer_chunk_count = data_collection.ChunkCount();
716 full_outer_chunk_done = 0;
717
718 auto num_threads = TaskScheduler::GetScheduler(context&: sink.context).NumberOfThreads();
719 full_outer_chunks_per_thread = MaxValue<idx_t>(a: (full_outer_chunk_count + num_threads - 1) / num_threads, b: 1);
720
721 global_stage = HashJoinSourceStage::SCAN_HT;
722}
723
724bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate) {
725 D_ASSERT(lstate.TaskFinished());
726
727 lock_guard<mutex> guard(lock);
728 switch (global_stage.load()) {
729 case HashJoinSourceStage::BUILD:
730 if (build_chunk_idx != build_chunk_count) {
731 lstate.local_stage = global_stage;
732 lstate.build_chunk_idx_from = build_chunk_idx;
733 build_chunk_idx = MinValue<idx_t>(a: build_chunk_count, b: build_chunk_idx + build_chunks_per_thread);
734 lstate.build_chunk_idx_to = build_chunk_idx;
735 return true;
736 }
737 break;
738 case HashJoinSourceStage::PROBE:
739 if (sink.probe_spill->consumer && sink.probe_spill->consumer->AssignChunk(state&: lstate.probe_local_scan)) {
740 lstate.local_stage = global_stage;
741 lstate.empty_ht_probe_in_progress = false;
742 return true;
743 }
744 break;
745 case HashJoinSourceStage::SCAN_HT:
746 if (full_outer_chunk_idx != full_outer_chunk_count) {
747 lstate.local_stage = global_stage;
748 lstate.full_outer_chunk_idx_from = full_outer_chunk_idx;
749 full_outer_chunk_idx =
750 MinValue<idx_t>(a: full_outer_chunk_count, b: full_outer_chunk_idx + full_outer_chunks_per_thread);
751 lstate.full_outer_chunk_idx_to = full_outer_chunk_idx;
752 return true;
753 }
754 break;
755 case HashJoinSourceStage::DONE:
756 break;
757 default:
758 throw InternalException("Unexpected HashJoinSourceStage in AssignTask!");
759 }
760 return false;
761}
762
763HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator)
764 : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER) {
765 auto &chunk_state = probe_local_scan.current_chunk_state;
766 chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY;
767
768 auto &sink = op.sink_state->Cast<HashJoinGlobalSinkState>();
769 probe_chunk.Initialize(allocator, types: sink.probe_types);
770 join_keys.Initialize(allocator, types: op.condition_types);
771 payload.Initialize(allocator, types: op.children[0]->types);
772
773 // Store the indices of the columns to reference them easily
774 idx_t col_idx = 0;
775 for (; col_idx < op.condition_types.size(); col_idx++) {
776 join_key_indices.push_back(x: col_idx);
777 }
778 for (; col_idx < sink.probe_types.size() - 1; col_idx++) {
779 payload_indices.push_back(x: col_idx);
780 }
781}
782
783void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate,
784 DataChunk &chunk) {
785 switch (local_stage) {
786 case HashJoinSourceStage::BUILD:
787 ExternalBuild(sink, gstate);
788 break;
789 case HashJoinSourceStage::PROBE:
790 ExternalProbe(sink, gstate, chunk);
791 break;
792 case HashJoinSourceStage::SCAN_HT:
793 ExternalScanHT(sink, gstate, chunk);
794 break;
795 default:
796 throw InternalException("Unexpected HashJoinSourceStage in ExecuteTask!");
797 }
798}
799
800bool HashJoinLocalSourceState::TaskFinished() {
801 switch (local_stage) {
802 case HashJoinSourceStage::INIT:
803 case HashJoinSourceStage::BUILD:
804 return true;
805 case HashJoinSourceStage::PROBE:
806 return scan_structure == nullptr && !empty_ht_probe_in_progress;
807 case HashJoinSourceStage::SCAN_HT:
808 return full_outer_scan_state == nullptr;
809 default:
810 throw InternalException("Unexpected HashJoinSourceStage in TaskFinished!");
811 }
812}
813
814void HashJoinLocalSourceState::ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate) {
815 D_ASSERT(local_stage == HashJoinSourceStage::BUILD);
816
817 auto &ht = *sink.hash_table;
818 ht.Finalize(chunk_idx_from: build_chunk_idx_from, chunk_idx_to: build_chunk_idx_to, parallel: true);
819
820 lock_guard<mutex> guard(gstate.lock);
821 gstate.build_chunk_done += build_chunk_idx_to - build_chunk_idx_from;
822}
823
824void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate,
825 DataChunk &chunk) {
826 D_ASSERT(local_stage == HashJoinSourceStage::PROBE && sink.hash_table->finalized);
827
828 if (scan_structure) {
829 // Still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe)
830 scan_structure->Next(keys&: join_keys, left&: payload, result&: chunk);
831 if (chunk.size() != 0) {
832 return;
833 }
834 }
835
836 if (scan_structure || empty_ht_probe_in_progress) {
837 // Previous probe is done
838 scan_structure = nullptr;
839 empty_ht_probe_in_progress = false;
840 sink.probe_spill->consumer->FinishChunk(state&: probe_local_scan);
841 lock_guard<mutex> lock(gstate.lock);
842 gstate.probe_chunk_done++;
843 return;
844 }
845
846 // Scan input chunk for next probe
847 sink.probe_spill->consumer->ScanChunk(state&: probe_local_scan, chunk&: probe_chunk);
848
849 // Get the probe chunk columns/hashes
850 join_keys.ReferenceColumns(other&: probe_chunk, column_ids: join_key_indices);
851 payload.ReferenceColumns(other&: probe_chunk, column_ids: payload_indices);
852 auto precomputed_hashes = &probe_chunk.data.back();
853
854 if (sink.hash_table->Count() == 0 && !gstate.op.EmptyResultIfRHSIsEmpty()) {
855 gstate.op.ConstructEmptyJoinResult(join_type: sink.hash_table->join_type, has_null: sink.hash_table->has_null, input&: payload, result&: chunk);
856 empty_ht_probe_in_progress = true;
857 return;
858 }
859
860 // Perform the probe
861 scan_structure = sink.hash_table->Probe(keys&: join_keys, precomputed_hashes);
862 scan_structure->Next(keys&: join_keys, left&: payload, result&: chunk);
863}
864
865void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate,
866 DataChunk &chunk) {
867 D_ASSERT(local_stage == HashJoinSourceStage::SCAN_HT);
868
869 if (!full_outer_scan_state) {
870 full_outer_scan_state = make_uniq<JoinHTScanState>(args&: sink.hash_table->GetDataCollection(),
871 args&: full_outer_chunk_idx_from, args&: full_outer_chunk_idx_to);
872 }
873 sink.hash_table->ScanFullOuter(state&: *full_outer_scan_state, addresses, result&: chunk);
874
875 if (chunk.size() == 0) {
876 full_outer_scan_state = nullptr;
877 lock_guard<mutex> guard(gstate.lock);
878 gstate.full_outer_chunk_done += full_outer_chunk_idx_to - full_outer_chunk_idx_from;
879 }
880}
881
882SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk &chunk,
883 OperatorSourceInput &input) const {
884 auto &sink = sink_state->Cast<HashJoinGlobalSinkState>();
885 auto &gstate = input.global_state.Cast<HashJoinGlobalSourceState>();
886 auto &lstate = input.local_state.Cast<HashJoinLocalSourceState>();
887 sink.scanned_data = true;
888
889 if (!sink.external && !IsRightOuterJoin(type: join_type)) {
890 return SourceResultType::FINISHED;
891 }
892
893 if (gstate.global_stage == HashJoinSourceStage::INIT) {
894 gstate.Initialize(sink);
895 }
896
897 // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done
898 // Therefore, we loop until we've produced tuples, or until the operator is actually done
899 while (gstate.global_stage != HashJoinSourceStage::DONE && chunk.size() == 0) {
900 if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) {
901 lstate.ExecuteTask(sink, gstate, chunk);
902 } else {
903 lock_guard<mutex> guard(gstate.lock);
904 gstate.TryPrepareNextStage(sink);
905 }
906 }
907
908 return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
909}
910
911} // namespace duckdb
912