1 | #include "duckdb/execution/operator/join/physical_hash_join.hpp" |
2 | |
3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
4 | #include "duckdb/execution/expression_executor.hpp" |
5 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
6 | #include "duckdb/function/function_binder.hpp" |
7 | #include "duckdb/main/client_context.hpp" |
8 | #include "duckdb/main/query_profiler.hpp" |
9 | #include "duckdb/parallel/base_pipeline_event.hpp" |
10 | #include "duckdb/parallel/pipeline.hpp" |
11 | #include "duckdb/parallel/thread_context.hpp" |
12 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
13 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
14 | #include "duckdb/storage/buffer_manager.hpp" |
15 | #include "duckdb/storage/storage_manager.hpp" |
16 | |
17 | namespace duckdb { |
18 | |
19 | PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr<PhysicalOperator> left, |
20 | unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type, |
21 | const vector<idx_t> &left_projection_map, |
22 | const vector<idx_t> &right_projection_map_p, vector<LogicalType> delim_types, |
23 | idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_stats) |
24 | : PhysicalComparisonJoin(op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), |
25 | right_projection_map(right_projection_map_p), delim_types(std::move(delim_types)), |
26 | perfect_join_statistics(std::move(perfect_join_stats)) { |
27 | |
28 | children.push_back(x: std::move(left)); |
29 | children.push_back(x: std::move(right)); |
30 | |
31 | D_ASSERT(left_projection_map.empty()); |
32 | for (auto &condition : conditions) { |
33 | condition_types.push_back(x: condition.left->return_type); |
34 | } |
35 | |
36 | // for ANTI, SEMI and MARK join, we only need to store the keys, so for these the build types are empty |
37 | if (join_type != JoinType::ANTI && join_type != JoinType::SEMI && join_type != JoinType::MARK) { |
38 | build_types = LogicalOperator::MapTypes(types: children[1]->GetTypes(), projection_map: right_projection_map); |
39 | } |
40 | } |
41 | |
42 | PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr<PhysicalOperator> left, |
43 | unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type, |
44 | idx_t estimated_cardinality, PerfectHashJoinStats perfect_join_state) |
45 | : PhysicalHashJoin(op, std::move(left), std::move(right), std::move(cond), join_type, {}, {}, {}, |
46 | estimated_cardinality, std::move(perfect_join_state)) { |
47 | } |
48 | |
49 | //===--------------------------------------------------------------------===// |
50 | // Sink |
51 | //===--------------------------------------------------------------------===// |
52 | class HashJoinGlobalSinkState : public GlobalSinkState { |
53 | public: |
54 | HashJoinGlobalSinkState(const PhysicalHashJoin &op, ClientContext &context_p) |
55 | : context(context_p), finalized(false), scanned_data(false) { |
56 | hash_table = op.InitializeHashTable(context); |
57 | |
58 | // for perfect hash join |
59 | perfect_join_executor = make_uniq<PerfectHashJoinExecutor>(args: op, args&: *hash_table, args: op.perfect_join_statistics); |
60 | // for external hash join |
61 | external = ClientConfig::GetConfig(context).force_external; |
62 | // Set probe types |
63 | const auto &payload_types = op.children[0]->types; |
64 | probe_types.insert(position: probe_types.end(), first: op.condition_types.begin(), last: op.condition_types.end()); |
65 | probe_types.insert(position: probe_types.end(), first: payload_types.begin(), last: payload_types.end()); |
66 | probe_types.emplace_back(args: LogicalType::HASH); |
67 | } |
68 | |
69 | void ScheduleFinalize(Pipeline &pipeline, Event &event); |
70 | void InitializeProbeSpill(); |
71 | |
72 | public: |
73 | ClientContext &context; |
74 | //! Global HT used by the join |
75 | unique_ptr<JoinHashTable> hash_table; |
76 | //! The perfect hash join executor (if any) |
77 | unique_ptr<PerfectHashJoinExecutor> perfect_join_executor; |
78 | //! Whether or not the hash table has been finalized |
79 | bool finalized = false; |
80 | |
81 | //! Whether we are doing an external join |
82 | bool external; |
83 | |
84 | //! Hash tables built by each thread |
85 | mutex lock; |
86 | vector<unique_ptr<JoinHashTable>> local_hash_tables; |
87 | |
88 | //! Excess probe data gathered during Sink |
89 | vector<LogicalType> probe_types; |
90 | unique_ptr<JoinHashTable::ProbeSpill> probe_spill; |
91 | |
92 | //! Whether or not we have started scanning data using GetData |
93 | atomic<bool> scanned_data; |
94 | }; |
95 | |
96 | class HashJoinLocalSinkState : public LocalSinkState { |
97 | public: |
98 | HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context) : build_executor(context) { |
99 | auto &allocator = Allocator::Get(context); |
100 | if (!op.right_projection_map.empty()) { |
101 | build_chunk.Initialize(allocator, types: op.build_types); |
102 | } |
103 | for (auto &cond : op.conditions) { |
104 | build_executor.AddExpression(expr: *cond.right); |
105 | } |
106 | join_keys.Initialize(allocator, types: op.condition_types); |
107 | |
108 | hash_table = op.InitializeHashTable(context); |
109 | |
110 | hash_table->GetSinkCollection().InitializeAppendState(state&: append_state); |
111 | } |
112 | |
113 | public: |
114 | PartitionedTupleDataAppendState append_state; |
115 | |
116 | DataChunk build_chunk; |
117 | DataChunk join_keys; |
118 | ExpressionExecutor build_executor; |
119 | |
120 | //! Thread-local HT |
121 | unique_ptr<JoinHashTable> hash_table; |
122 | }; |
123 | |
124 | unique_ptr<JoinHashTable> PhysicalHashJoin::InitializeHashTable(ClientContext &context) const { |
125 | auto result = |
126 | make_uniq<JoinHashTable>(args&: BufferManager::GetBufferManager(context), args: conditions, args: build_types, args: join_type); |
127 | result->max_ht_size = double(BufferManager::GetBufferManager(context).GetMaxMemory()) * 0.6; |
128 | if (!delim_types.empty() && join_type == JoinType::MARK) { |
129 | // correlated MARK join |
130 | if (delim_types.size() + 1 == conditions.size()) { |
131 | // the correlated MARK join has one more condition than the amount of correlated columns |
132 | // this is the case in a correlated ANY() expression |
133 | // in this case we need to keep track of additional entries, namely: |
134 | // - (1) the total amount of elements per group |
135 | // - (2) the amount of non-null elements per group |
136 | // we need these to correctly deal with the cases of either: |
137 | // - (1) the group being empty [in which case the result is always false, even if the comparison is NULL] |
138 | // - (2) the group containing a NULL value [in which case FALSE becomes NULL] |
139 | auto &info = result->correlated_mark_join_info; |
140 | |
141 | vector<LogicalType> payload_types; |
142 | vector<BoundAggregateExpression *> correlated_aggregates; |
143 | unique_ptr<BoundAggregateExpression> aggr; |
144 | |
145 | // jury-rigging the GroupedAggregateHashTable |
146 | // we need a count_star and a count to get counts with and without NULLs |
147 | |
148 | FunctionBinder function_binder(context); |
149 | aggr = function_binder.BindAggregateFunction(bound_function: CountStarFun::GetFunction(), children: {}, filter: nullptr, |
150 | aggr_type: AggregateType::NON_DISTINCT); |
151 | correlated_aggregates.push_back(x: &*aggr); |
152 | payload_types.push_back(x: aggr->return_type); |
153 | info.correlated_aggregates.push_back(x: std::move(aggr)); |
154 | |
155 | auto count_fun = CountFun::GetFunction(); |
156 | vector<unique_ptr<Expression>> children; |
157 | // this is a dummy but we need it to make the hash table understand whats going on |
158 | children.push_back(x: make_uniq_base<Expression, BoundReferenceExpression>(args&: count_fun.return_type, args: 0)); |
159 | aggr = function_binder.BindAggregateFunction(bound_function: count_fun, children: std::move(children), filter: nullptr, |
160 | aggr_type: AggregateType::NON_DISTINCT); |
161 | correlated_aggregates.push_back(x: &*aggr); |
162 | payload_types.push_back(x: aggr->return_type); |
163 | info.correlated_aggregates.push_back(x: std::move(aggr)); |
164 | |
165 | auto &allocator = Allocator::Get(context); |
166 | info.correlated_counts = make_uniq<GroupedAggregateHashTable>(args&: context, args&: allocator, args: delim_types, |
167 | args&: payload_types, args&: correlated_aggregates); |
168 | info.correlated_types = delim_types; |
169 | info.group_chunk.Initialize(allocator, types: delim_types); |
170 | info.result_chunk.Initialize(allocator, types: payload_types); |
171 | } |
172 | } |
173 | return result; |
174 | } |
175 | |
176 | unique_ptr<GlobalSinkState> PhysicalHashJoin::GetGlobalSinkState(ClientContext &context) const { |
177 | return make_uniq<HashJoinGlobalSinkState>(args: *this, args&: context); |
178 | } |
179 | |
180 | unique_ptr<LocalSinkState> PhysicalHashJoin::GetLocalSinkState(ExecutionContext &context) const { |
181 | return make_uniq<HashJoinLocalSinkState>(args: *this, args&: context.client); |
182 | } |
183 | |
184 | SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { |
185 | auto &lstate = input.local_state.Cast<HashJoinLocalSinkState>(); |
186 | |
187 | // resolve the join keys for the right chunk |
188 | lstate.join_keys.Reset(); |
189 | lstate.build_executor.Execute(input&: chunk, result&: lstate.join_keys); |
190 | |
191 | // build the HT |
192 | auto &ht = *lstate.hash_table; |
193 | if (!right_projection_map.empty()) { |
194 | // there is a projection map: fill the build chunk with the projected columns |
195 | lstate.build_chunk.Reset(); |
196 | lstate.build_chunk.SetCardinality(chunk); |
197 | for (idx_t i = 0; i < right_projection_map.size(); i++) { |
198 | lstate.build_chunk.data[i].Reference(other&: chunk.data[right_projection_map[i]]); |
199 | } |
200 | ht.Build(append_state&: lstate.append_state, keys&: lstate.join_keys, input&: lstate.build_chunk); |
201 | } else if (!build_types.empty()) { |
202 | // there is not a projected map: place the entire right chunk in the HT |
203 | ht.Build(append_state&: lstate.append_state, keys&: lstate.join_keys, input&: chunk); |
204 | } else { |
205 | // there are only keys: place an empty chunk in the payload |
206 | lstate.build_chunk.SetCardinality(chunk.size()); |
207 | ht.Build(append_state&: lstate.append_state, keys&: lstate.join_keys, input&: lstate.build_chunk); |
208 | } |
209 | |
210 | return SinkResultType::NEED_MORE_INPUT; |
211 | } |
212 | |
213 | void PhysicalHashJoin::Combine(ExecutionContext &context, GlobalSinkState &gstate_p, LocalSinkState &lstate_p) const { |
214 | auto &gstate = gstate_p.Cast<HashJoinGlobalSinkState>(); |
215 | auto &lstate = lstate_p.Cast<HashJoinLocalSinkState>(); |
216 | if (lstate.hash_table) { |
217 | lstate.hash_table->GetSinkCollection().FlushAppendState(state&: lstate.append_state); |
218 | lock_guard<mutex> local_ht_lock(gstate.lock); |
219 | gstate.local_hash_tables.push_back(x: std::move(lstate.hash_table)); |
220 | } |
221 | auto &client_profiler = QueryProfiler::Get(context&: context.client); |
222 | context.thread.profiler.Flush(phys_op: *this, expression_executor&: lstate.build_executor, name: "build_executor" , id: 1); |
223 | client_profiler.Flush(profiler&: context.thread.profiler); |
224 | } |
225 | |
226 | //===--------------------------------------------------------------------===// |
227 | // Finalize |
228 | //===--------------------------------------------------------------------===// |
229 | class HashJoinFinalizeTask : public ExecutorTask { |
230 | public: |
231 | HashJoinFinalizeTask(shared_ptr<Event> event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p, |
232 | idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p) |
233 | : ExecutorTask(context), event(std::move(event_p)), sink(sink_p), chunk_idx_from(chunk_idx_from_p), |
234 | chunk_idx_to(chunk_idx_to_p), parallel(parallel_p) { |
235 | } |
236 | |
237 | TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { |
238 | sink.hash_table->Finalize(chunk_idx_from, chunk_idx_to, parallel); |
239 | event->FinishTask(); |
240 | return TaskExecutionResult::TASK_FINISHED; |
241 | } |
242 | |
243 | private: |
244 | shared_ptr<Event> event; |
245 | HashJoinGlobalSinkState &sink; |
246 | idx_t chunk_idx_from; |
247 | idx_t chunk_idx_to; |
248 | bool parallel; |
249 | }; |
250 | |
251 | class HashJoinFinalizeEvent : public BasePipelineEvent { |
252 | public: |
253 | HashJoinFinalizeEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink) |
254 | : BasePipelineEvent(pipeline_p), sink(sink) { |
255 | } |
256 | |
257 | HashJoinGlobalSinkState &sink; |
258 | |
259 | public: |
260 | void Schedule() override { |
261 | auto &context = pipeline->GetClientContext(); |
262 | |
263 | vector<shared_ptr<Task>> finalize_tasks; |
264 | auto &ht = *sink.hash_table; |
265 | const auto chunk_count = ht.GetDataCollection().ChunkCount(); |
266 | const idx_t num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); |
267 | if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && !context.config.verify_parallelism)) { |
268 | // Single-threaded finalize |
269 | finalize_tasks.push_back( |
270 | x: make_uniq<HashJoinFinalizeTask>(args: shared_from_this(), args&: context, args&: sink, args: 0, args: chunk_count, args: false)); |
271 | } else { |
272 | // Parallel finalize |
273 | auto chunks_per_thread = MaxValue<idx_t>(a: (chunk_count + num_threads - 1) / num_threads, b: 1); |
274 | |
275 | idx_t chunk_idx = 0; |
276 | for (idx_t thread_idx = 0; thread_idx < num_threads; thread_idx++) { |
277 | auto chunk_idx_from = chunk_idx; |
278 | auto chunk_idx_to = MinValue<idx_t>(a: chunk_idx_from + chunks_per_thread, b: chunk_count); |
279 | finalize_tasks.push_back(x: make_uniq<HashJoinFinalizeTask>(args: shared_from_this(), args&: context, args&: sink, |
280 | args&: chunk_idx_from, args&: chunk_idx_to, args: true)); |
281 | chunk_idx = chunk_idx_to; |
282 | if (chunk_idx == chunk_count) { |
283 | break; |
284 | } |
285 | } |
286 | } |
287 | SetTasks(std::move(finalize_tasks)); |
288 | } |
289 | |
290 | void FinishEvent() override { |
291 | sink.hash_table->GetDataCollection().VerifyEverythingPinned(); |
292 | sink.hash_table->finalized = true; |
293 | } |
294 | |
295 | static constexpr const idx_t PARALLEL_CONSTRUCT_THRESHOLD = 1048576; |
296 | }; |
297 | |
298 | void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) { |
299 | if (hash_table->Count() == 0) { |
300 | hash_table->finalized = true; |
301 | return; |
302 | } |
303 | hash_table->InitializePointerTable(); |
304 | auto new_event = make_shared<HashJoinFinalizeEvent>(args&: pipeline, args&: *this); |
305 | event.InsertEvent(replacement_event: std::move(new_event)); |
306 | } |
307 | |
308 | void HashJoinGlobalSinkState::InitializeProbeSpill() { |
309 | lock_guard<mutex> guard(lock); |
310 | if (!probe_spill) { |
311 | probe_spill = make_uniq<JoinHashTable::ProbeSpill>(args&: *hash_table, args&: context, args&: probe_types); |
312 | } |
313 | } |
314 | |
315 | class HashJoinPartitionTask : public ExecutorTask { |
316 | public: |
317 | HashJoinPartitionTask(shared_ptr<Event> event_p, ClientContext &context, JoinHashTable &global_ht, |
318 | JoinHashTable &local_ht) |
319 | : ExecutorTask(context), event(std::move(event_p)), global_ht(global_ht), local_ht(local_ht) { |
320 | } |
321 | |
322 | TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { |
323 | local_ht.Partition(global_ht); |
324 | event->FinishTask(); |
325 | return TaskExecutionResult::TASK_FINISHED; |
326 | } |
327 | |
328 | private: |
329 | shared_ptr<Event> event; |
330 | |
331 | JoinHashTable &global_ht; |
332 | JoinHashTable &local_ht; |
333 | }; |
334 | |
335 | class HashJoinPartitionEvent : public BasePipelineEvent { |
336 | public: |
337 | HashJoinPartitionEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink, |
338 | vector<unique_ptr<JoinHashTable>> &local_hts) |
339 | : BasePipelineEvent(pipeline_p), sink(sink), local_hts(local_hts) { |
340 | } |
341 | |
342 | HashJoinGlobalSinkState &sink; |
343 | vector<unique_ptr<JoinHashTable>> &local_hts; |
344 | |
345 | public: |
346 | void Schedule() override { |
347 | auto &context = pipeline->GetClientContext(); |
348 | vector<shared_ptr<Task>> partition_tasks; |
349 | partition_tasks.reserve(n: local_hts.size()); |
350 | for (auto &local_ht : local_hts) { |
351 | partition_tasks.push_back( |
352 | x: make_uniq<HashJoinPartitionTask>(args: shared_from_this(), args&: context, args&: *sink.hash_table, args&: *local_ht)); |
353 | } |
354 | SetTasks(std::move(partition_tasks)); |
355 | } |
356 | |
357 | void FinishEvent() override { |
358 | local_hts.clear(); |
359 | sink.hash_table->PrepareExternalFinalize(); |
360 | sink.ScheduleFinalize(pipeline&: *pipeline, event&: *this); |
361 | } |
362 | }; |
363 | |
364 | SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, |
365 | GlobalSinkState &gstate) const { |
366 | auto &sink = gstate.Cast<HashJoinGlobalSinkState>(); |
367 | auto &ht = *sink.hash_table; |
368 | |
369 | sink.external = ht.RequiresExternalJoin(config&: context.config, local_hts&: sink.local_hash_tables); |
370 | if (sink.external) { |
371 | sink.perfect_join_executor.reset(); |
372 | if (ht.RequiresPartitioning(config&: context.config, local_hts&: sink.local_hash_tables)) { |
373 | auto new_event = make_shared<HashJoinPartitionEvent>(args&: pipeline, args&: sink, args&: sink.local_hash_tables); |
374 | event.InsertEvent(replacement_event: std::move(new_event)); |
375 | } else { |
376 | for (auto &local_ht : sink.local_hash_tables) { |
377 | ht.Merge(other&: *local_ht); |
378 | } |
379 | sink.local_hash_tables.clear(); |
380 | sink.hash_table->PrepareExternalFinalize(); |
381 | sink.ScheduleFinalize(pipeline, event); |
382 | } |
383 | sink.finalized = true; |
384 | return SinkFinalizeType::READY; |
385 | } else { |
386 | for (auto &local_ht : sink.local_hash_tables) { |
387 | ht.Merge(other&: *local_ht); |
388 | } |
389 | sink.local_hash_tables.clear(); |
390 | ht.Unpartition(); |
391 | } |
392 | |
393 | // check for possible perfect hash table |
394 | auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin(); |
395 | if (use_perfect_hash) { |
396 | D_ASSERT(ht.equality_types.size() == 1); |
397 | auto key_type = ht.equality_types[0]; |
398 | use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(type&: key_type); |
399 | } |
400 | // In case of a large build side or duplicates, use regular hash join |
401 | if (!use_perfect_hash) { |
402 | sink.perfect_join_executor.reset(); |
403 | sink.ScheduleFinalize(pipeline, event); |
404 | } |
405 | sink.finalized = true; |
406 | if (ht.Count() == 0 && EmptyResultIfRHSIsEmpty()) { |
407 | return SinkFinalizeType::NO_OUTPUT_POSSIBLE; |
408 | } |
409 | return SinkFinalizeType::READY; |
410 | } |
411 | |
412 | //===--------------------------------------------------------------------===// |
413 | // Operator |
414 | //===--------------------------------------------------------------------===// |
415 | class HashJoinOperatorState : public CachingOperatorState { |
416 | public: |
417 | explicit HashJoinOperatorState(ClientContext &context) : probe_executor(context), initialized(false) { |
418 | } |
419 | |
420 | DataChunk join_keys; |
421 | ExpressionExecutor probe_executor; |
422 | unique_ptr<JoinHashTable::ScanStructure> scan_structure; |
423 | unique_ptr<OperatorState> perfect_hash_join_state; |
424 | |
425 | bool initialized; |
426 | JoinHashTable::ProbeSpillLocalAppendState spill_state; |
427 | //! Chunk to sink data into for external join |
428 | DataChunk spill_chunk; |
429 | |
430 | public: |
431 | void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { |
432 | context.thread.profiler.Flush(phys_op: op, expression_executor&: probe_executor, name: "probe_executor" , id: 0); |
433 | } |
434 | }; |
435 | |
436 | unique_ptr<OperatorState> PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const { |
437 | auto &allocator = Allocator::Get(context&: context.client); |
438 | auto &sink = sink_state->Cast<HashJoinGlobalSinkState>(); |
439 | auto state = make_uniq<HashJoinOperatorState>(args&: context.client); |
440 | if (sink.perfect_join_executor) { |
441 | state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); |
442 | } else { |
443 | state->join_keys.Initialize(allocator, types: condition_types); |
444 | for (auto &cond : conditions) { |
445 | state->probe_executor.AddExpression(expr: *cond.left); |
446 | } |
447 | } |
448 | if (sink.external) { |
449 | state->spill_chunk.Initialize(allocator, types: sink.probe_types); |
450 | sink.InitializeProbeSpill(); |
451 | } |
452 | |
453 | return std::move(state); |
454 | } |
455 | |
456 | OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, |
457 | GlobalOperatorState &gstate, OperatorState &state_p) const { |
458 | auto &state = state_p.Cast<HashJoinOperatorState>(); |
459 | auto &sink = sink_state->Cast<HashJoinGlobalSinkState>(); |
460 | D_ASSERT(sink.finalized); |
461 | D_ASSERT(!sink.scanned_data); |
462 | |
463 | // some initialization for external hash join |
464 | if (sink.external && !state.initialized) { |
465 | if (!sink.probe_spill) { |
466 | sink.InitializeProbeSpill(); |
467 | } |
468 | state.spill_state = sink.probe_spill->RegisterThread(); |
469 | state.initialized = true; |
470 | } |
471 | |
472 | if (sink.hash_table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { |
473 | return OperatorResultType::FINISHED; |
474 | } |
475 | |
476 | if (sink.perfect_join_executor) { |
477 | D_ASSERT(!sink.external); |
478 | return sink.perfect_join_executor->ProbePerfectHashTable(context, input, chunk, state&: *state.perfect_hash_join_state); |
479 | } |
480 | |
481 | if (state.scan_structure) { |
482 | // still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) |
483 | state.scan_structure->Next(keys&: state.join_keys, left&: input, result&: chunk); |
484 | if (chunk.size() > 0) { |
485 | return OperatorResultType::HAVE_MORE_OUTPUT; |
486 | } |
487 | state.scan_structure = nullptr; |
488 | return OperatorResultType::NEED_MORE_INPUT; |
489 | } |
490 | |
491 | // probe the HT |
492 | if (sink.hash_table->Count() == 0) { |
493 | ConstructEmptyJoinResult(join_type: sink.hash_table->join_type, has_null: sink.hash_table->has_null, input, result&: chunk); |
494 | return OperatorResultType::NEED_MORE_INPUT; |
495 | } |
496 | |
497 | // resolve the join keys for the left chunk |
498 | state.join_keys.Reset(); |
499 | state.probe_executor.Execute(input, result&: state.join_keys); |
500 | |
501 | // perform the actual probe |
502 | if (sink.external) { |
503 | state.scan_structure = sink.hash_table->ProbeAndSpill(keys&: state.join_keys, payload&: input, probe_spill&: *sink.probe_spill, |
504 | spill_state&: state.spill_state, spill_chunk&: state.spill_chunk); |
505 | } else { |
506 | state.scan_structure = sink.hash_table->Probe(keys&: state.join_keys); |
507 | } |
508 | state.scan_structure->Next(keys&: state.join_keys, left&: input, result&: chunk); |
509 | return OperatorResultType::HAVE_MORE_OUTPUT; |
510 | } |
511 | |
512 | //===--------------------------------------------------------------------===// |
513 | // Source |
514 | //===--------------------------------------------------------------------===// |
//! Stages of the source phase of the hash join
enum class HashJoinSourceStage : uint8_t {
	INIT,
	BUILD,
	PROBE,
	SCAN_HT,
	DONE
};

// forward declaration: the global source state assigns tasks to the local source state
class HashJoinLocalSourceState;
518 | |
519 | class HashJoinGlobalSourceState : public GlobalSourceState { |
520 | public: |
521 | HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context); |
522 | |
523 | //! Initialize this source state using the info in the sink |
524 | void Initialize(HashJoinGlobalSinkState &sink); |
525 | //! Try to prepare the next stage |
526 | void TryPrepareNextStage(HashJoinGlobalSinkState &sink); |
527 | //! Prepare the next build/probe/scan_ht stage for external hash join (must hold lock) |
528 | void PrepareBuild(HashJoinGlobalSinkState &sink); |
529 | void PrepareProbe(HashJoinGlobalSinkState &sink); |
530 | void PrepareScanHT(HashJoinGlobalSinkState &sink); |
531 | //! Assigns a task to a local source state |
532 | bool AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate); |
533 | |
534 | idx_t MaxThreads() override { |
535 | return probe_count / ((idx_t)STANDARD_VECTOR_SIZE * parallel_scan_chunk_count); |
536 | } |
537 | |
538 | public: |
539 | const PhysicalHashJoin &op; |
540 | |
541 | //! For synchronizing the external hash join |
542 | atomic<HashJoinSourceStage> global_stage; |
543 | mutex lock; |
544 | |
545 | //! For HT build synchronization |
546 | idx_t build_chunk_idx; |
547 | idx_t build_chunk_count; |
548 | idx_t build_chunk_done; |
549 | idx_t build_chunks_per_thread; |
550 | |
551 | //! For probe synchronization |
552 | idx_t probe_chunk_count; |
553 | idx_t probe_chunk_done; |
554 | |
555 | //! To determine the number of threads |
556 | idx_t probe_count; |
557 | idx_t parallel_scan_chunk_count; |
558 | |
559 | //! For full/outer synchronization |
560 | idx_t full_outer_chunk_idx; |
561 | idx_t full_outer_chunk_count; |
562 | idx_t full_outer_chunk_done; |
563 | idx_t full_outer_chunks_per_thread; |
564 | }; |
565 | |
566 | class HashJoinLocalSourceState : public LocalSourceState { |
567 | public: |
568 | HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator); |
569 | |
570 | //! Do the work this thread has been assigned |
571 | void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); |
572 | //! Whether this thread has finished the work it has been assigned |
573 | bool TaskFinished(); |
574 | //! Build, probe and scan for external hash join |
575 | void ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate); |
576 | void ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); |
577 | void ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); |
578 | |
579 | public: |
580 | //! The stage that this thread was assigned work for |
581 | HashJoinSourceStage local_stage; |
582 | //! Vector with pointers here so we don't have to re-initialize |
583 | Vector addresses; |
584 | |
585 | //! Chunks assigned to this thread for building the pointer table |
586 | idx_t build_chunk_idx_from; |
587 | idx_t build_chunk_idx_to; |
588 | |
589 | //! Local scan state for probe spill |
590 | ColumnDataConsumerScanState probe_local_scan; |
591 | //! Chunks for holding the scanned probe collection |
592 | DataChunk probe_chunk; |
593 | DataChunk join_keys; |
594 | DataChunk payload; |
595 | //! Column indices to easily reference the join keys/payload columns in probe_chunk |
596 | vector<idx_t> join_key_indices; |
597 | vector<idx_t> payload_indices; |
598 | //! Scan structure for the external probe |
599 | unique_ptr<JoinHashTable::ScanStructure> scan_structure; |
600 | bool empty_ht_probe_in_progress; |
601 | |
602 | //! Chunks assigned to this thread for a full/outer scan |
603 | idx_t full_outer_chunk_idx_from; |
604 | idx_t full_outer_chunk_idx_to; |
605 | unique_ptr<JoinHTScanState> full_outer_scan_state; |
606 | }; |
607 | |
608 | unique_ptr<GlobalSourceState> PhysicalHashJoin::GetGlobalSourceState(ClientContext &context) const { |
609 | return make_uniq<HashJoinGlobalSourceState>(args: *this, args&: context); |
610 | } |
611 | |
612 | unique_ptr<LocalSourceState> PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context, |
613 | GlobalSourceState &gstate) const { |
614 | return make_uniq<HashJoinLocalSourceState>(args: *this, args&: Allocator::Get(context&: context.client)); |
615 | } |
616 | |
617 | HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, ClientContext &context) |
618 | : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), |
619 | probe_chunk_done(0), probe_count(op.children[0]->estimated_cardinality), |
620 | parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) { |
621 | } |
622 | |
623 | void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { |
624 | lock_guard<mutex> init_lock(lock); |
625 | if (global_stage != HashJoinSourceStage::INIT) { |
626 | // Another thread initialized |
627 | return; |
628 | } |
629 | |
630 | // Finalize the probe spill |
631 | if (sink.probe_spill) { |
632 | sink.probe_spill->Finalize(); |
633 | } |
634 | |
635 | global_stage = HashJoinSourceStage::PROBE; |
636 | TryPrepareNextStage(sink); |
637 | } |
638 | |
639 | void HashJoinGlobalSourceState::TryPrepareNextStage(HashJoinGlobalSinkState &sink) { |
640 | switch (global_stage.load()) { |
641 | case HashJoinSourceStage::BUILD: |
642 | if (build_chunk_done == build_chunk_count) { |
643 | sink.hash_table->GetDataCollection().VerifyEverythingPinned(); |
644 | sink.hash_table->finalized = true; |
645 | PrepareProbe(sink); |
646 | } |
647 | break; |
648 | case HashJoinSourceStage::PROBE: |
649 | if (probe_chunk_done == probe_chunk_count) { |
650 | if (IsRightOuterJoin(type: op.join_type)) { |
651 | PrepareScanHT(sink); |
652 | } else { |
653 | PrepareBuild(sink); |
654 | } |
655 | } |
656 | break; |
657 | case HashJoinSourceStage::SCAN_HT: |
658 | if (full_outer_chunk_done == full_outer_chunk_count) { |
659 | PrepareBuild(sink); |
660 | } |
661 | break; |
662 | default: |
663 | break; |
664 | } |
665 | } |
666 | |
667 | void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { |
668 | D_ASSERT(global_stage != HashJoinSourceStage::BUILD); |
669 | auto &ht = *sink.hash_table; |
670 | |
671 | // Try to put the next partitions in the block collection of the HT |
672 | if (!sink.external || !ht.PrepareExternalFinalize()) { |
673 | global_stage = HashJoinSourceStage::DONE; |
674 | return; |
675 | } |
676 | |
677 | auto &data_collection = ht.GetDataCollection(); |
678 | if (data_collection.Count() == 0 && op.EmptyResultIfRHSIsEmpty()) { |
679 | PrepareBuild(sink); |
680 | return; |
681 | } |
682 | |
683 | build_chunk_idx = 0; |
684 | build_chunk_count = data_collection.ChunkCount(); |
685 | build_chunk_done = 0; |
686 | |
687 | auto num_threads = TaskScheduler::GetScheduler(context&: sink.context).NumberOfThreads(); |
688 | build_chunks_per_thread = MaxValue<idx_t>(a: (build_chunk_count + num_threads - 1) / num_threads, b: 1); |
689 | |
690 | ht.InitializePointerTable(); |
691 | |
692 | global_stage = HashJoinSourceStage::BUILD; |
693 | } |
694 | |
695 | void HashJoinGlobalSourceState::PrepareProbe(HashJoinGlobalSinkState &sink) { |
696 | sink.probe_spill->PrepareNextProbe(); |
697 | const auto &consumer = *sink.probe_spill->consumer; |
698 | |
699 | probe_chunk_count = consumer.Count() == 0 ? 0 : consumer.ChunkCount(); |
700 | probe_chunk_done = 0; |
701 | |
702 | global_stage = HashJoinSourceStage::PROBE; |
703 | if (probe_chunk_count == 0) { |
704 | TryPrepareNextStage(sink); |
705 | return; |
706 | } |
707 | } |
708 | |
709 | void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { |
710 | D_ASSERT(global_stage != HashJoinSourceStage::SCAN_HT); |
711 | auto &ht = *sink.hash_table; |
712 | |
713 | auto &data_collection = ht.GetDataCollection(); |
714 | full_outer_chunk_idx = 0; |
715 | full_outer_chunk_count = data_collection.ChunkCount(); |
716 | full_outer_chunk_done = 0; |
717 | |
718 | auto num_threads = TaskScheduler::GetScheduler(context&: sink.context).NumberOfThreads(); |
719 | full_outer_chunks_per_thread = MaxValue<idx_t>(a: (full_outer_chunk_count + num_threads - 1) / num_threads, b: 1); |
720 | |
721 | global_stage = HashJoinSourceStage::SCAN_HT; |
722 | } |
723 | |
724 | bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate) { |
725 | D_ASSERT(lstate.TaskFinished()); |
726 | |
727 | lock_guard<mutex> guard(lock); |
728 | switch (global_stage.load()) { |
729 | case HashJoinSourceStage::BUILD: |
730 | if (build_chunk_idx != build_chunk_count) { |
731 | lstate.local_stage = global_stage; |
732 | lstate.build_chunk_idx_from = build_chunk_idx; |
733 | build_chunk_idx = MinValue<idx_t>(a: build_chunk_count, b: build_chunk_idx + build_chunks_per_thread); |
734 | lstate.build_chunk_idx_to = build_chunk_idx; |
735 | return true; |
736 | } |
737 | break; |
738 | case HashJoinSourceStage::PROBE: |
739 | if (sink.probe_spill->consumer && sink.probe_spill->consumer->AssignChunk(state&: lstate.probe_local_scan)) { |
740 | lstate.local_stage = global_stage; |
741 | lstate.empty_ht_probe_in_progress = false; |
742 | return true; |
743 | } |
744 | break; |
745 | case HashJoinSourceStage::SCAN_HT: |
746 | if (full_outer_chunk_idx != full_outer_chunk_count) { |
747 | lstate.local_stage = global_stage; |
748 | lstate.full_outer_chunk_idx_from = full_outer_chunk_idx; |
749 | full_outer_chunk_idx = |
750 | MinValue<idx_t>(a: full_outer_chunk_count, b: full_outer_chunk_idx + full_outer_chunks_per_thread); |
751 | lstate.full_outer_chunk_idx_to = full_outer_chunk_idx; |
752 | return true; |
753 | } |
754 | break; |
755 | case HashJoinSourceStage::DONE: |
756 | break; |
757 | default: |
758 | throw InternalException("Unexpected HashJoinSourceStage in AssignTask!" ); |
759 | } |
760 | return false; |
761 | } |
762 | |
763 | HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, Allocator &allocator) |
764 | : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER) { |
765 | auto &chunk_state = probe_local_scan.current_chunk_state; |
766 | chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; |
767 | |
768 | auto &sink = op.sink_state->Cast<HashJoinGlobalSinkState>(); |
769 | probe_chunk.Initialize(allocator, types: sink.probe_types); |
770 | join_keys.Initialize(allocator, types: op.condition_types); |
771 | payload.Initialize(allocator, types: op.children[0]->types); |
772 | |
773 | // Store the indices of the columns to reference them easily |
774 | idx_t col_idx = 0; |
775 | for (; col_idx < op.condition_types.size(); col_idx++) { |
776 | join_key_indices.push_back(x: col_idx); |
777 | } |
778 | for (; col_idx < sink.probe_types.size() - 1; col_idx++) { |
779 | payload_indices.push_back(x: col_idx); |
780 | } |
781 | } |
782 | |
783 | void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, |
784 | DataChunk &chunk) { |
785 | switch (local_stage) { |
786 | case HashJoinSourceStage::BUILD: |
787 | ExternalBuild(sink, gstate); |
788 | break; |
789 | case HashJoinSourceStage::PROBE: |
790 | ExternalProbe(sink, gstate, chunk); |
791 | break; |
792 | case HashJoinSourceStage::SCAN_HT: |
793 | ExternalScanHT(sink, gstate, chunk); |
794 | break; |
795 | default: |
796 | throw InternalException("Unexpected HashJoinSourceStage in ExecuteTask!" ); |
797 | } |
798 | } |
799 | |
800 | bool HashJoinLocalSourceState::TaskFinished() { |
801 | switch (local_stage) { |
802 | case HashJoinSourceStage::INIT: |
803 | case HashJoinSourceStage::BUILD: |
804 | return true; |
805 | case HashJoinSourceStage::PROBE: |
806 | return scan_structure == nullptr && !empty_ht_probe_in_progress; |
807 | case HashJoinSourceStage::SCAN_HT: |
808 | return full_outer_scan_state == nullptr; |
809 | default: |
810 | throw InternalException("Unexpected HashJoinSourceStage in TaskFinished!" ); |
811 | } |
812 | } |
813 | |
814 | void HashJoinLocalSourceState::ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate) { |
815 | D_ASSERT(local_stage == HashJoinSourceStage::BUILD); |
816 | |
817 | auto &ht = *sink.hash_table; |
818 | ht.Finalize(chunk_idx_from: build_chunk_idx_from, chunk_idx_to: build_chunk_idx_to, parallel: true); |
819 | |
820 | lock_guard<mutex> guard(gstate.lock); |
821 | gstate.build_chunk_done += build_chunk_idx_to - build_chunk_idx_from; |
822 | } |
823 | |
824 | void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, |
825 | DataChunk &chunk) { |
826 | D_ASSERT(local_stage == HashJoinSourceStage::PROBE && sink.hash_table->finalized); |
827 | |
828 | if (scan_structure) { |
829 | // Still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) |
830 | scan_structure->Next(keys&: join_keys, left&: payload, result&: chunk); |
831 | if (chunk.size() != 0) { |
832 | return; |
833 | } |
834 | } |
835 | |
836 | if (scan_structure || empty_ht_probe_in_progress) { |
837 | // Previous probe is done |
838 | scan_structure = nullptr; |
839 | empty_ht_probe_in_progress = false; |
840 | sink.probe_spill->consumer->FinishChunk(state&: probe_local_scan); |
841 | lock_guard<mutex> lock(gstate.lock); |
842 | gstate.probe_chunk_done++; |
843 | return; |
844 | } |
845 | |
846 | // Scan input chunk for next probe |
847 | sink.probe_spill->consumer->ScanChunk(state&: probe_local_scan, chunk&: probe_chunk); |
848 | |
849 | // Get the probe chunk columns/hashes |
850 | join_keys.ReferenceColumns(other&: probe_chunk, column_ids: join_key_indices); |
851 | payload.ReferenceColumns(other&: probe_chunk, column_ids: payload_indices); |
852 | auto precomputed_hashes = &probe_chunk.data.back(); |
853 | |
854 | if (sink.hash_table->Count() == 0 && !gstate.op.EmptyResultIfRHSIsEmpty()) { |
855 | gstate.op.ConstructEmptyJoinResult(join_type: sink.hash_table->join_type, has_null: sink.hash_table->has_null, input&: payload, result&: chunk); |
856 | empty_ht_probe_in_progress = true; |
857 | return; |
858 | } |
859 | |
860 | // Perform the probe |
861 | scan_structure = sink.hash_table->Probe(keys&: join_keys, precomputed_hashes); |
862 | scan_structure->Next(keys&: join_keys, left&: payload, result&: chunk); |
863 | } |
864 | |
865 | void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, |
866 | DataChunk &chunk) { |
867 | D_ASSERT(local_stage == HashJoinSourceStage::SCAN_HT); |
868 | |
869 | if (!full_outer_scan_state) { |
870 | full_outer_scan_state = make_uniq<JoinHTScanState>(args&: sink.hash_table->GetDataCollection(), |
871 | args&: full_outer_chunk_idx_from, args&: full_outer_chunk_idx_to); |
872 | } |
873 | sink.hash_table->ScanFullOuter(state&: *full_outer_scan_state, addresses, result&: chunk); |
874 | |
875 | if (chunk.size() == 0) { |
876 | full_outer_scan_state = nullptr; |
877 | lock_guard<mutex> guard(gstate.lock); |
878 | gstate.full_outer_chunk_done += full_outer_chunk_idx_to - full_outer_chunk_idx_from; |
879 | } |
880 | } |
881 | |
882 | SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk &chunk, |
883 | OperatorSourceInput &input) const { |
884 | auto &sink = sink_state->Cast<HashJoinGlobalSinkState>(); |
885 | auto &gstate = input.global_state.Cast<HashJoinGlobalSourceState>(); |
886 | auto &lstate = input.local_state.Cast<HashJoinLocalSourceState>(); |
887 | sink.scanned_data = true; |
888 | |
889 | if (!sink.external && !IsRightOuterJoin(type: join_type)) { |
890 | return SourceResultType::FINISHED; |
891 | } |
892 | |
893 | if (gstate.global_stage == HashJoinSourceStage::INIT) { |
894 | gstate.Initialize(sink); |
895 | } |
896 | |
897 | // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done |
898 | // Therefore, we loop until we've produced tuples, or until the operator is actually done |
899 | while (gstate.global_stage != HashJoinSourceStage::DONE && chunk.size() == 0) { |
900 | if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { |
901 | lstate.ExecuteTask(sink, gstate, chunk); |
902 | } else { |
903 | lock_guard<mutex> guard(gstate.lock); |
904 | gstate.TryPrepareNextStage(sink); |
905 | } |
906 | } |
907 | |
908 | return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; |
909 | } |
910 | |
911 | } // namespace duckdb |
912 | |