#include "duckdb/execution/operator/join/physical_range_join.hpp"

#include "duckdb/common/fast_mem.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/row_operations/row_operations.hpp"
#include "duckdb/common/sort/comparators.hpp"
#include "duckdb/common/sort/sort.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/parallel/base_pipeline_event.hpp"
#include "duckdb/parallel/thread_context.hpp"

#include <thread>

namespace duckdb {

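// A LocalSortedTable evaluates the join key expressions for one side of the
// range join (child 0 = left, child 1 = right) and sinks the keys, together
// with the payload rows, into a thread-local sort state.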
PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op,
                                                      const idx_t child)
    : op(op), executor(context), has_null(0), count(0) {
	// Initialize order clause expression executor and key DataChunk
	vector<LogicalType> types;
	for (const auto &cond : op.conditions) {
		const auto &expr = child ? cond.right : cond.left;
		executor.AddExpression(*expr);

		types.push_back(expr->return_type);
	}
	auto &allocator = Allocator::Get(context);
	keys.Initialize(allocator, types);
}

void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) {
	// Initialize local state (if necessary)
	if (!local_sort_state.initialized) {
		local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager);
	}

	// Obtain sorting columns
	keys.Reset();
	executor.Execute(input, keys);

	// Count the NULLs so we can exclude them later
	has_null += MergeNulls(op.conditions);
	count += keys.size();

	// Only sort the primary key
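	// (the conditions were reordered so that a range predicate comes first);
	// the remaining conditions are evaluated later as tail conditions.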
	DataChunk join_head;
	join_head.data.emplace_back(keys.data[0]);
	join_head.SetCardinality(keys.size());

	// Sink the data into the local sort state
	local_sort_state.SinkChunk(join_head, input);
}

PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector<BoundOrderByNode> &orders,
                                                        RowLayout &payload_layout)
    : global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), count(0),
      memory_per_thread(0) {
	D_ASSERT(orders.size() == 1);

	// Set external (can be forced with the PRAGMA)
	auto &config = ClientConfig::GetConfig(context);
	global_sort_state.external = config.force_external;
	memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context);
}

void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable &ltable) {
	global_sort_state.AddLocalState(ltable.local_sort_state);
	has_null += ltable.has_null;
	count += ltable.count;
}

void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() {
	found_match = make_unsafe_uniq_array<bool>(Count());
	memset(found_match.get(), 0, sizeof(bool) * Count());
}

void PhysicalRangeJoin::GlobalSortedTable::Print() {
	global_sort_state.Print();
}

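// A merge task performs one thread's share of a cascaded merge round over the
// sorted runs of a GlobalSortedTable.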
class RangeJoinMergeTask : public ExecutorTask {
public:
	using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable;

public:
	RangeJoinMergeTask(shared_ptr<Event> event_p, ClientContext &context, GlobalSortedTable &table)
	    : ExecutorTask(context), event(std::move(event_p)), context(context), table(table) {
	}

	TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
		// Initialize the merge sorter and run it until the current merge round is done
		auto &global_sort_state = table.global_sort_state;
		MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context));
		merge_sorter.PerformInMergeRound();
		event->FinishTask();

		return TaskExecutionResult::TASK_FINISHED;
	}

private:
	shared_ptr<Event> event;
	ClientContext &context;
	GlobalSortedTable &table;
};

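// A merge event schedules one merge task per available thread; when all tasks
// have finished, it either completes the sort or schedules the next round.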
class RangeJoinMergeEvent : public BasePipelineEvent {
public:
	using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable;

public:
	RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p)
	    : BasePipelineEvent(pipeline_p), table(table_p) {
	}

	GlobalSortedTable &table;

public:
	void Schedule() override {
		auto &context = pipeline->GetClientContext();

		// Schedule tasks equal to the number of threads, which will each merge multiple partitions
		auto &ts = TaskScheduler::GetScheduler(context);
		idx_t num_threads = ts.NumberOfThreads();

		vector<shared_ptr<Task>> iejoin_tasks;
		for (idx_t tnum = 0; tnum < num_threads; tnum++) {
			iejoin_tasks.push_back(make_uniq<RangeJoinMergeTask>(shared_from_this(), context, table));
		}
		SetTasks(std::move(iejoin_tasks));
	}

	void FinishEvent() override {
		auto &global_sort_state = table.global_sort_state;

		global_sort_state.CompleteMergeRound(true);
		if (global_sort_state.sorted_blocks.size() > 1) {
			// Multiple blocks remaining: Schedule the next round
			table.ScheduleMergeTasks(*pipeline, *this);
		}
	}
};

void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) {
	// Initialize global sort state for a round of merging
	global_sort_state.InitializeMergeRound();
	auto new_event = make_shared<RangeJoinMergeEvent>(*this, pipeline);
	event.InsertEvent(std::move(new_event));
}

void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) {
	// Prepare for merge sort phase
	global_sort_state.PrepareMergePhase();

	// Start the merge phase or finish if a merge is not necessary
	if (global_sort_state.sorted_blocks.size() > 1) {
		ScheduleMergeTasks(pipeline, event);
	}
}

PhysicalRangeJoin::PhysicalRangeJoin(LogicalOperator &op, PhysicalOperatorType type, unique_ptr<PhysicalOperator> left,
                                     unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type,
                                     idx_t estimated_cardinality)
    : PhysicalComparisonJoin(op, type, std::move(cond), join_type, estimated_cardinality) {
	// Reorder the conditions so that ranges are at the front.
	// TODO: use stats to improve the choice?
	// TODO: Prefer fixed length types?
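	// For example, (l.a < r.a AND l.b <> r.b) is reordered to put l.a < r.a
	// first: the range predicate becomes the sort key, while l.b <> r.b is
	// evaluated afterwards as a tail condition (see SelectJoinTail below).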
	if (conditions.size() > 1) {
		auto conditions_p = std::move(conditions);
		conditions.resize(conditions_p.size());
		idx_t range_position = 0;
		idx_t other_position = conditions_p.size();
		for (idx_t i = 0; i < conditions_p.size(); ++i) {
			switch (conditions_p[i].comparison) {
			case ExpressionType::COMPARE_LESSTHAN:
			case ExpressionType::COMPARE_LESSTHANOREQUALTO:
			case ExpressionType::COMPARE_GREATERTHAN:
			case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
				conditions[range_position++] = std::move(conditions_p[i]);
				break;
			default:
				conditions[--other_position] = std::move(conditions_p[i]);
				break;
			}
		}
	}

	children.push_back(std::move(left));
	children.push_back(std::move(right));
}

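// There are three cases to consider when merging the key validity masks:
// 1. all keys are constant vectors: the chunk is either all NULL or NULL-free;
// 2. multiple key columns: each mask is merged into the (flattened) primary;
// 3. a single key column: the NULLs can simply be counted directly.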
idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(const vector<JoinCondition> &conditions) {
	// Merge the validity masks of the comparison keys into the primary
	// Return the number of NULLs in the resulting chunk
	D_ASSERT(keys.ColumnCount() > 0);
	const auto count = keys.size();

	size_t all_constant = 0;
	for (auto &v : keys.data) {
		if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) {
			++all_constant;
		}
	}

	auto &primary = keys.data[0];
	if (all_constant == keys.data.size()) {
		// Either all NULL or no NULLs
		for (auto &v : keys.data) {
			if (ConstantVector::IsNull(v)) {
				ConstantVector::SetNull(primary, true);
				return count;
			}
		}
		return 0;
	} else if (keys.ColumnCount() > 1) {
		// Flatten the primary, as it will need to merge arbitrary validity masks
		primary.Flatten(count);
		auto &pvalidity = FlatVector::Validity(primary);
		D_ASSERT(keys.ColumnCount() == conditions.size());
		for (size_t c = 1; c < keys.data.size(); ++c) {
			// Skip comparisons that accept NULLs
			if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) {
				continue;
			}
			// ToUnifiedFormat the rest, as the sort code will do this anyway.
			auto &v = keys.data[c];
			UnifiedVectorFormat vdata;
			v.ToUnifiedFormat(count, vdata);
			auto &vvalidity = vdata.validity;
			if (vvalidity.AllValid()) {
				continue;
			}
			pvalidity.EnsureWritable();
			switch (v.GetVectorType()) {
			case VectorType::FLAT_VECTOR: {
				// Merge entire entries
				auto pmask = pvalidity.GetData();
				const auto entry_count = pvalidity.EntryCount(count);
				for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) {
					pmask[entry_idx] &= vvalidity.GetValidityEntry(entry_idx);
				}
				break;
			}
			case VectorType::CONSTANT_VECTOR:
				// All or nothing
				if (ConstantVector::IsNull(v)) {
					pvalidity.SetAllInvalid(count);
					return count;
				}
				break;
			default:
				// One by one
				for (idx_t i = 0; i < count; ++i) {
					const auto idx = vdata.sel->get_index(i);
					if (!vvalidity.RowIsValidUnsafe(idx)) {
						pvalidity.SetInvalidUnsafe(i);
					}
				}
				break;
			}
		}
		return count - pvalidity.CountValid(count);
	} else {
		return count - VectorOperations::CountNotNull(primary, count);
	}
}

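// Gather the payload columns of the rows selected by `result` from a sorted
// block into `payload`, starting at column offset `left_cols`, and return the
// pinned heap handle.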
BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx,
                                                   const SelectionVector &result, const idx_t result_count,
                                                   const idx_t left_cols) {
	// There should only be one sorted block if the data has been sorted
	D_ASSERT(state.sorted_blocks.size() == 1);
	SBScanState read_state(state.buffer_manager, state);
	read_state.sb = state.sorted_blocks[0].get();
	auto &sorted_data = *read_state.sb->payload_data;

	read_state.SetIndices(block_idx, 0);
	read_state.PinData(sorted_data);
	const auto data_ptr = read_state.DataPtr(sorted_data);
	data_ptr_t heap_ptr = nullptr;

	// Set up a batch of pointers to scan data from
	Vector addresses(LogicalType::POINTER, result_count);
	auto data_pointers = FlatVector::GetData<data_ptr_t>(addresses);

	// Set up the data pointers for the values that are actually referenced
	const idx_t &row_width = sorted_data.layout.GetRowWidth();

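	// The result selection may reference the same row many times; create one
	// gather address per run of identical row indices and record in gsel which
	// address each result position maps to.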
	auto prev_idx = result.get_index(0);
	SelectionVector gsel(result_count);
	idx_t addr_count = 0;
	gsel.set_index(0, addr_count);
	data_pointers[addr_count] = data_ptr + prev_idx * row_width;
	for (idx_t i = 1; i < result_count; ++i) {
		const auto row_idx = result.get_index(i);
		if (row_idx != prev_idx) {
			data_pointers[++addr_count] = data_ptr + row_idx * row_width;
			prev_idx = row_idx;
		}
		gsel.set_index(i, addr_count);
	}
	++addr_count;

	// Unswizzle the offsets back to pointers (if needed)
	if (!sorted_data.layout.AllConstant() && state.external) {
		heap_ptr = read_state.payload_heap_handle.Ptr();
	}

	// Deserialize the payload data
	auto sel = FlatVector::IncrementalSelectionVector();
	for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) {
		auto &col = payload.data[left_cols + col_no];
		RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr);
		col.Slice(gsel, result_count);
	}

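	// Hand the heap handle back to the caller so the heap block stays pinned
	// while the gathered vectors are still being read.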
	return std::move(read_state.payload_heap_handle);
}

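// Evaluate one of the remaining (tail) join conditions on candidate pairs,
// writing the positions that pass to true_sel and returning the match count.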
idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right,
                                        const SelectionVector *sel, idx_t count, SelectionVector *true_sel) {
	switch (condition) {
	case ExpressionType::COMPARE_NOTEQUAL:
		return VectorOperations::NotEquals(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_LESSTHAN:
		return VectorOperations::LessThan(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_GREATERTHAN:
		return VectorOperations::GreaterThan(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_LESSTHANOREQUALTO:
		return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
		return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_DISTINCT_FROM:
		return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
		return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, nullptr);
	case ExpressionType::COMPARE_EQUAL:
		return VectorOperations::Equals(left, right, sel, count, true_sel, nullptr);
	default:
		throw InternalException("Unsupported comparison type for PhysicalRangeJoin");
	}

	return count;
}

} // namespace duckdb