1#include "duckdb/execution/operator/join/physical_range_join.hpp"
2
3#include "duckdb/common/fast_mem.hpp"
4#include "duckdb/common/operator/comparison_operators.hpp"
5#include "duckdb/common/row_operations/row_operations.hpp"
6#include "duckdb/common/sort/comparators.hpp"
7#include "duckdb/common/sort/sort.hpp"
8#include "duckdb/common/vector_operations/vector_operations.hpp"
9#include "duckdb/execution/expression_executor.hpp"
10#include "duckdb/main/client_context.hpp"
11#include "duckdb/parallel/base_pipeline_event.hpp"
12#include "duckdb/parallel/thread_context.hpp"
13
14#include <thread>
15
16namespace duckdb {
17
18PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op,
19 const idx_t child)
20 : op(op), executor(context), has_null(0), count(0) {
21 // Initialize order clause expression executor and key DataChunk
22 vector<LogicalType> types;
23 for (const auto &cond : op.conditions) {
24 const auto &expr = child ? cond.right : cond.left;
25 executor.AddExpression(expr: *expr);
26
27 types.push_back(x: expr->return_type);
28 }
29 auto &allocator = Allocator::Get(context);
30 keys.Initialize(allocator, types);
31}
32
33void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) {
34 // Initialize local state (if necessary)
35 if (!local_sort_state.initialized) {
36 local_sort_state.Initialize(global_sort_state, buffer_manager_p&: global_sort_state.buffer_manager);
37 }
38
39 // Obtain sorting columns
40 keys.Reset();
41 executor.Execute(input, result&: keys);
42
43 // Count the NULLs so we can exclude them later
44 has_null += MergeNulls(conditions: op.conditions);
45 count += keys.size();
46
47 // Only sort the primary key
48 DataChunk join_head;
49 join_head.data.emplace_back(args&: keys.data[0]);
50 join_head.SetCardinality(keys.size());
51
52 // Sink the data into the local sort state
53 local_sort_state.SinkChunk(sort&: join_head, payload&: input);
54}
55
56PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector<BoundOrderByNode> &orders,
57 RowLayout &payload_layout)
58 : global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), count(0),
59 memory_per_thread(0) {
60 D_ASSERT(orders.size() == 1);
61
62 // Set external (can be forced with the PRAGMA)
63 auto &config = ClientConfig::GetConfig(context);
64 global_sort_state.external = config.force_external;
65 memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context);
66}
67
68void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable &ltable) {
69 global_sort_state.AddLocalState(local_sort_state&: ltable.local_sort_state);
70 has_null += ltable.has_null;
71 count += ltable.count;
72}
73
74void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() {
75 found_match = make_unsafe_uniq_array<bool>(n: Count());
76 memset(s: found_match.get(), c: 0, n: sizeof(bool) * Count());
77}
78
79void PhysicalRangeJoin::GlobalSortedTable::Print() {
80 global_sort_state.Print();
81}
82
83class RangeJoinMergeTask : public ExecutorTask {
84public:
85 using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable;
86
87public:
88 RangeJoinMergeTask(shared_ptr<Event> event_p, ClientContext &context, GlobalSortedTable &table)
89 : ExecutorTask(context), event(std::move(event_p)), context(context), table(table) {
90 }
91
92 TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override {
93 // Initialize iejoin sorted and iterate until done
94 auto &global_sort_state = table.global_sort_state;
95 MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context));
96 merge_sorter.PerformInMergeRound();
97 event->FinishTask();
98
99 return TaskExecutionResult::TASK_FINISHED;
100 }
101
102private:
103 shared_ptr<Event> event;
104 ClientContext &context;
105 GlobalSortedTable &table;
106};
107
108class RangeJoinMergeEvent : public BasePipelineEvent {
109public:
110 using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable;
111
112public:
113 RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p)
114 : BasePipelineEvent(pipeline_p), table(table_p) {
115 }
116
117 GlobalSortedTable &table;
118
119public:
120 void Schedule() override {
121 auto &context = pipeline->GetClientContext();
122
123 // Schedule tasks equal to the number of threads, which will each merge multiple partitions
124 auto &ts = TaskScheduler::GetScheduler(context);
125 idx_t num_threads = ts.NumberOfThreads();
126
127 vector<shared_ptr<Task>> iejoin_tasks;
128 for (idx_t tnum = 0; tnum < num_threads; tnum++) {
129 iejoin_tasks.push_back(x: make_uniq<RangeJoinMergeTask>(args: shared_from_this(), args&: context, args&: table));
130 }
131 SetTasks(std::move(iejoin_tasks));
132 }
133
134 void FinishEvent() override {
135 auto &global_sort_state = table.global_sort_state;
136
137 global_sort_state.CompleteMergeRound(keep_radix_data: true);
138 if (global_sort_state.sorted_blocks.size() > 1) {
139 // Multiple blocks remaining: Schedule the next round
140 table.ScheduleMergeTasks(pipeline&: *pipeline, event&: *this);
141 }
142 }
143};
144
145void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) {
146 // Initialize global sort state for a round of merging
147 global_sort_state.InitializeMergeRound();
148 auto new_event = make_shared<RangeJoinMergeEvent>(args&: *this, args&: pipeline);
149 event.InsertEvent(replacement_event: std::move(new_event));
150}
151
152void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) {
153 // Prepare for merge sort phase
154 global_sort_state.PrepareMergePhase();
155
156 // Start the merge phase or finish if a merge is not necessary
157 if (global_sort_state.sorted_blocks.size() > 1) {
158 ScheduleMergeTasks(pipeline, event);
159 }
160}
161
162PhysicalRangeJoin::PhysicalRangeJoin(LogicalOperator &op, PhysicalOperatorType type, unique_ptr<PhysicalOperator> left,
163 unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type,
164 idx_t estimated_cardinality)
165 : PhysicalComparisonJoin(op, type, std::move(cond), join_type, estimated_cardinality) {
166 // Reorder the conditions so that ranges are at the front.
167 // TODO: use stats to improve the choice?
168 // TODO: Prefer fixed length types?
169 if (conditions.size() > 1) {
170 auto conditions_p = std::move(conditions);
171 conditions.resize(new_size: conditions_p.size());
172 idx_t range_position = 0;
173 idx_t other_position = conditions_p.size();
174 for (idx_t i = 0; i < conditions_p.size(); ++i) {
175 switch (conditions_p[i].comparison) {
176 case ExpressionType::COMPARE_LESSTHAN:
177 case ExpressionType::COMPARE_LESSTHANOREQUALTO:
178 case ExpressionType::COMPARE_GREATERTHAN:
179 case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
180 conditions[range_position++] = std::move(conditions_p[i]);
181 break;
182 default:
183 conditions[--other_position] = std::move(conditions_p[i]);
184 break;
185 }
186 }
187 }
188
189 children.push_back(x: std::move(left));
190 children.push_back(x: std::move(right));
191}
192
193idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(const vector<JoinCondition> &conditions) {
194 // Merge the validity masks of the comparison keys into the primary
195 // Return the number of NULLs in the resulting chunk
196 D_ASSERT(keys.ColumnCount() > 0);
197 const auto count = keys.size();
198
199 size_t all_constant = 0;
200 for (auto &v : keys.data) {
201 if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) {
202 ++all_constant;
203 }
204 }
205
206 auto &primary = keys.data[0];
207 if (all_constant == keys.data.size()) {
208 // Either all NULL or no NULLs
209 for (auto &v : keys.data) {
210 if (ConstantVector::IsNull(vector: v)) {
211 ConstantVector::SetNull(vector&: primary, is_null: true);
212 return count;
213 }
214 }
215 return 0;
216 } else if (keys.ColumnCount() > 1) {
217 // Flatten the primary, as it will need to merge arbitrary validity masks
218 primary.Flatten(count);
219 auto &pvalidity = FlatVector::Validity(vector&: primary);
220 D_ASSERT(keys.ColumnCount() == conditions.size());
221 for (size_t c = 1; c < keys.data.size(); ++c) {
222 // Skip comparisons that accept NULLs
223 if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) {
224 continue;
225 }
226 // ToUnifiedFormat the rest, as the sort code will do this anyway.
227 auto &v = keys.data[c];
228 UnifiedVectorFormat vdata;
229 v.ToUnifiedFormat(count, data&: vdata);
230 auto &vvalidity = vdata.validity;
231 if (vvalidity.AllValid()) {
232 continue;
233 }
234 pvalidity.EnsureWritable();
235 switch (v.GetVectorType()) {
236 case VectorType::FLAT_VECTOR: {
237 // Merge entire entries
238 auto pmask = pvalidity.GetData();
239 const auto entry_count = pvalidity.EntryCount(count);
240 for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) {
241 pmask[entry_idx] &= vvalidity.GetValidityEntry(entry_idx);
242 }
243 break;
244 }
245 case VectorType::CONSTANT_VECTOR:
246 // All or nothing
247 if (ConstantVector::IsNull(vector: v)) {
248 pvalidity.SetAllInvalid(count);
249 return count;
250 }
251 break;
252 default:
253 // One by one
254 for (idx_t i = 0; i < count; ++i) {
255 const auto idx = vdata.sel->get_index(idx: i);
256 if (!vvalidity.RowIsValidUnsafe(row_idx: idx)) {
257 pvalidity.SetInvalidUnsafe(i);
258 }
259 }
260 break;
261 }
262 }
263 return count - pvalidity.CountValid(count);
264 } else {
265 return count - VectorOperations::CountNotNull(input&: primary, count);
266 }
267}
268
269BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx,
270 const SelectionVector &result, const idx_t result_count,
271 const idx_t left_cols) {
272 // There should only be one sorted block if they have been sorted
273 D_ASSERT(state.sorted_blocks.size() == 1);
274 SBScanState read_state(state.buffer_manager, state);
275 read_state.sb = state.sorted_blocks[0].get();
276 auto &sorted_data = *read_state.sb->payload_data;
277
278 read_state.SetIndices(block_idx_to: block_idx, entry_idx_to: 0);
279 read_state.PinData(sd&: sorted_data);
280 const auto data_ptr = read_state.DataPtr(sd&: sorted_data);
281 data_ptr_t heap_ptr = nullptr;
282
283 // Set up a batch of pointers to scan data from
284 Vector addresses(LogicalType::POINTER, result_count);
285 auto data_pointers = FlatVector::GetData<data_ptr_t>(vector&: addresses);
286
287 // Set up the data pointers for the values that are actually referenced
288 const idx_t &row_width = sorted_data.layout.GetRowWidth();
289
290 auto prev_idx = result.get_index(idx: 0);
291 SelectionVector gsel(result_count);
292 idx_t addr_count = 0;
293 gsel.set_index(idx: 0, loc: addr_count);
294 data_pointers[addr_count] = data_ptr + prev_idx * row_width;
295 for (idx_t i = 1; i < result_count; ++i) {
296 const auto row_idx = result.get_index(idx: i);
297 if (row_idx != prev_idx) {
298 data_pointers[++addr_count] = data_ptr + row_idx * row_width;
299 prev_idx = row_idx;
300 }
301 gsel.set_index(idx: i, loc: addr_count);
302 }
303 ++addr_count;
304
305 // Unswizzle the offsets back to pointers (if needed)
306 if (!sorted_data.layout.AllConstant() && state.external) {
307 heap_ptr = read_state.payload_heap_handle.Ptr();
308 }
309
310 // Deserialize the payload data
311 auto sel = FlatVector::IncrementalSelectionVector();
312 for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) {
313 auto &col = payload.data[left_cols + col_no];
314 RowOperations::Gather(rows&: addresses, row_sel: *sel, col, col_sel: *sel, count: addr_count, layout: sorted_data.layout, col_no, build_size: 0, heap_ptr);
315 col.Slice(sel: gsel, count: result_count);
316 }
317
318 return std::move(read_state.payload_heap_handle);
319}
320
321idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right,
322 const SelectionVector *sel, idx_t count, SelectionVector *true_sel) {
323 switch (condition) {
324 case ExpressionType::COMPARE_NOTEQUAL:
325 return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel: nullptr);
326 case ExpressionType::COMPARE_LESSTHAN:
327 return VectorOperations::LessThan(left, right, sel, count, true_sel, false_sel: nullptr);
328 case ExpressionType::COMPARE_GREATERTHAN:
329 return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel: nullptr);
330 case ExpressionType::COMPARE_LESSTHANOREQUALTO:
331 return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, false_sel: nullptr);
332 case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
333 return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel: nullptr);
334 case ExpressionType::COMPARE_DISTINCT_FROM:
335 return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, false_sel: nullptr);
336 case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
337 return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, false_sel: nullptr);
338 case ExpressionType::COMPARE_EQUAL:
339 return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel: nullptr);
340 default:
341 throw InternalException("Unsupported comparison type for PhysicalRangeJoin");
342 }
343
344 return count;
345}
346
347} // namespace duckdb
348