| 1 | #include "duckdb/execution/operator/join/physical_range_join.hpp" |
| 2 | |
| 3 | #include "duckdb/common/fast_mem.hpp" |
| 4 | #include "duckdb/common/operator/comparison_operators.hpp" |
| 5 | #include "duckdb/common/row_operations/row_operations.hpp" |
| 6 | #include "duckdb/common/sort/comparators.hpp" |
| 7 | #include "duckdb/common/sort/sort.hpp" |
| 8 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 9 | #include "duckdb/execution/expression_executor.hpp" |
| 10 | #include "duckdb/main/client_context.hpp" |
| 11 | #include "duckdb/parallel/base_pipeline_event.hpp" |
| 12 | #include "duckdb/parallel/thread_context.hpp" |
| 13 | |
| 14 | #include <thread> |
| 15 | |
| 16 | namespace duckdb { |
| 17 | |
| 18 | PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, |
| 19 | const idx_t child) |
| 20 | : op(op), executor(context), has_null(0), count(0) { |
| 21 | // Initialize order clause expression executor and key DataChunk |
| 22 | vector<LogicalType> types; |
| 23 | for (const auto &cond : op.conditions) { |
| 24 | const auto &expr = child ? cond.right : cond.left; |
| 25 | executor.AddExpression(expr: *expr); |
| 26 | |
| 27 | types.push_back(x: expr->return_type); |
| 28 | } |
| 29 | auto &allocator = Allocator::Get(context); |
| 30 | keys.Initialize(allocator, types); |
| 31 | } |
| 32 | |
| 33 | void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { |
| 34 | // Initialize local state (if necessary) |
| 35 | if (!local_sort_state.initialized) { |
| 36 | local_sort_state.Initialize(global_sort_state, buffer_manager_p&: global_sort_state.buffer_manager); |
| 37 | } |
| 38 | |
| 39 | // Obtain sorting columns |
| 40 | keys.Reset(); |
| 41 | executor.Execute(input, result&: keys); |
| 42 | |
| 43 | // Count the NULLs so we can exclude them later |
| 44 | has_null += MergeNulls(conditions: op.conditions); |
| 45 | count += keys.size(); |
| 46 | |
| 47 | // Only sort the primary key |
| 48 | DataChunk join_head; |
| 49 | join_head.data.emplace_back(args&: keys.data[0]); |
| 50 | join_head.SetCardinality(keys.size()); |
| 51 | |
| 52 | // Sink the data into the local sort state |
| 53 | local_sort_state.SinkChunk(sort&: join_head, payload&: input); |
| 54 | } |
| 55 | |
| 56 | PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector<BoundOrderByNode> &orders, |
| 57 | RowLayout &payload_layout) |
| 58 | : global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), count(0), |
| 59 | memory_per_thread(0) { |
| 60 | D_ASSERT(orders.size() == 1); |
| 61 | |
| 62 | // Set external (can be forced with the PRAGMA) |
| 63 | auto &config = ClientConfig::GetConfig(context); |
| 64 | global_sort_state.external = config.force_external; |
| 65 | memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); |
| 66 | } |
| 67 | |
| 68 | void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { |
| 69 | global_sort_state.AddLocalState(local_sort_state&: ltable.local_sort_state); |
| 70 | has_null += ltable.has_null; |
| 71 | count += ltable.count; |
| 72 | } |
| 73 | |
| 74 | void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { |
| 75 | found_match = make_unsafe_uniq_array<bool>(n: Count()); |
| 76 | memset(s: found_match.get(), c: 0, n: sizeof(bool) * Count()); |
| 77 | } |
| 78 | |
| 79 | void PhysicalRangeJoin::GlobalSortedTable::Print() { |
| 80 | global_sort_state.Print(); |
| 81 | } |
| 82 | |
| 83 | class RangeJoinMergeTask : public ExecutorTask { |
| 84 | public: |
| 85 | using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; |
| 86 | |
| 87 | public: |
| 88 | RangeJoinMergeTask(shared_ptr<Event> event_p, ClientContext &context, GlobalSortedTable &table) |
| 89 | : ExecutorTask(context), event(std::move(event_p)), context(context), table(table) { |
| 90 | } |
| 91 | |
| 92 | TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { |
| 93 | // Initialize iejoin sorted and iterate until done |
| 94 | auto &global_sort_state = table.global_sort_state; |
| 95 | MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); |
| 96 | merge_sorter.PerformInMergeRound(); |
| 97 | event->FinishTask(); |
| 98 | |
| 99 | return TaskExecutionResult::TASK_FINISHED; |
| 100 | } |
| 101 | |
| 102 | private: |
| 103 | shared_ptr<Event> event; |
| 104 | ClientContext &context; |
| 105 | GlobalSortedTable &table; |
| 106 | }; |
| 107 | |
| 108 | class RangeJoinMergeEvent : public BasePipelineEvent { |
| 109 | public: |
| 110 | using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; |
| 111 | |
| 112 | public: |
| 113 | RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) |
| 114 | : BasePipelineEvent(pipeline_p), table(table_p) { |
| 115 | } |
| 116 | |
| 117 | GlobalSortedTable &table; |
| 118 | |
| 119 | public: |
| 120 | void Schedule() override { |
| 121 | auto &context = pipeline->GetClientContext(); |
| 122 | |
| 123 | // Schedule tasks equal to the number of threads, which will each merge multiple partitions |
| 124 | auto &ts = TaskScheduler::GetScheduler(context); |
| 125 | idx_t num_threads = ts.NumberOfThreads(); |
| 126 | |
| 127 | vector<shared_ptr<Task>> iejoin_tasks; |
| 128 | for (idx_t tnum = 0; tnum < num_threads; tnum++) { |
| 129 | iejoin_tasks.push_back(x: make_uniq<RangeJoinMergeTask>(args: shared_from_this(), args&: context, args&: table)); |
| 130 | } |
| 131 | SetTasks(std::move(iejoin_tasks)); |
| 132 | } |
| 133 | |
| 134 | void FinishEvent() override { |
| 135 | auto &global_sort_state = table.global_sort_state; |
| 136 | |
| 137 | global_sort_state.CompleteMergeRound(keep_radix_data: true); |
| 138 | if (global_sort_state.sorted_blocks.size() > 1) { |
| 139 | // Multiple blocks remaining: Schedule the next round |
| 140 | table.ScheduleMergeTasks(pipeline&: *pipeline, event&: *this); |
| 141 | } |
| 142 | } |
| 143 | }; |
| 144 | |
| 145 | void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { |
| 146 | // Initialize global sort state for a round of merging |
| 147 | global_sort_state.InitializeMergeRound(); |
| 148 | auto new_event = make_shared<RangeJoinMergeEvent>(args&: *this, args&: pipeline); |
| 149 | event.InsertEvent(replacement_event: std::move(new_event)); |
| 150 | } |
| 151 | |
| 152 | void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { |
| 153 | // Prepare for merge sort phase |
| 154 | global_sort_state.PrepareMergePhase(); |
| 155 | |
| 156 | // Start the merge phase or finish if a merge is not necessary |
| 157 | if (global_sort_state.sorted_blocks.size() > 1) { |
| 158 | ScheduleMergeTasks(pipeline, event); |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | PhysicalRangeJoin::PhysicalRangeJoin(LogicalOperator &op, PhysicalOperatorType type, unique_ptr<PhysicalOperator> left, |
| 163 | unique_ptr<PhysicalOperator> right, vector<JoinCondition> cond, JoinType join_type, |
| 164 | idx_t estimated_cardinality) |
| 165 | : PhysicalComparisonJoin(op, type, std::move(cond), join_type, estimated_cardinality) { |
| 166 | // Reorder the conditions so that ranges are at the front. |
| 167 | // TODO: use stats to improve the choice? |
| 168 | // TODO: Prefer fixed length types? |
| 169 | if (conditions.size() > 1) { |
| 170 | auto conditions_p = std::move(conditions); |
| 171 | conditions.resize(new_size: conditions_p.size()); |
| 172 | idx_t range_position = 0; |
| 173 | idx_t other_position = conditions_p.size(); |
| 174 | for (idx_t i = 0; i < conditions_p.size(); ++i) { |
| 175 | switch (conditions_p[i].comparison) { |
| 176 | case ExpressionType::COMPARE_LESSTHAN: |
| 177 | case ExpressionType::COMPARE_LESSTHANOREQUALTO: |
| 178 | case ExpressionType::COMPARE_GREATERTHAN: |
| 179 | case ExpressionType::COMPARE_GREATERTHANOREQUALTO: |
| 180 | conditions[range_position++] = std::move(conditions_p[i]); |
| 181 | break; |
| 182 | default: |
| 183 | conditions[--other_position] = std::move(conditions_p[i]); |
| 184 | break; |
| 185 | } |
| 186 | } |
| 187 | } |
| 188 | |
| 189 | children.push_back(x: std::move(left)); |
| 190 | children.push_back(x: std::move(right)); |
| 191 | } |
| 192 | |
| 193 | idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(const vector<JoinCondition> &conditions) { |
| 194 | // Merge the validity masks of the comparison keys into the primary |
| 195 | // Return the number of NULLs in the resulting chunk |
| 196 | D_ASSERT(keys.ColumnCount() > 0); |
| 197 | const auto count = keys.size(); |
| 198 | |
| 199 | size_t all_constant = 0; |
| 200 | for (auto &v : keys.data) { |
| 201 | if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
| 202 | ++all_constant; |
| 203 | } |
| 204 | } |
| 205 | |
| 206 | auto &primary = keys.data[0]; |
| 207 | if (all_constant == keys.data.size()) { |
| 208 | // Either all NULL or no NULLs |
| 209 | for (auto &v : keys.data) { |
| 210 | if (ConstantVector::IsNull(vector: v)) { |
| 211 | ConstantVector::SetNull(vector&: primary, is_null: true); |
| 212 | return count; |
| 213 | } |
| 214 | } |
| 215 | return 0; |
| 216 | } else if (keys.ColumnCount() > 1) { |
| 217 | // Flatten the primary, as it will need to merge arbitrary validity masks |
| 218 | primary.Flatten(count); |
| 219 | auto &pvalidity = FlatVector::Validity(vector&: primary); |
| 220 | D_ASSERT(keys.ColumnCount() == conditions.size()); |
| 221 | for (size_t c = 1; c < keys.data.size(); ++c) { |
| 222 | // Skip comparisons that accept NULLs |
| 223 | if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) { |
| 224 | continue; |
| 225 | } |
| 226 | // ToUnifiedFormat the rest, as the sort code will do this anyway. |
| 227 | auto &v = keys.data[c]; |
| 228 | UnifiedVectorFormat vdata; |
| 229 | v.ToUnifiedFormat(count, data&: vdata); |
| 230 | auto &vvalidity = vdata.validity; |
| 231 | if (vvalidity.AllValid()) { |
| 232 | continue; |
| 233 | } |
| 234 | pvalidity.EnsureWritable(); |
| 235 | switch (v.GetVectorType()) { |
| 236 | case VectorType::FLAT_VECTOR: { |
| 237 | // Merge entire entries |
| 238 | auto pmask = pvalidity.GetData(); |
| 239 | const auto entry_count = pvalidity.EntryCount(count); |
| 240 | for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { |
| 241 | pmask[entry_idx] &= vvalidity.GetValidityEntry(entry_idx); |
| 242 | } |
| 243 | break; |
| 244 | } |
| 245 | case VectorType::CONSTANT_VECTOR: |
| 246 | // All or nothing |
| 247 | if (ConstantVector::IsNull(vector: v)) { |
| 248 | pvalidity.SetAllInvalid(count); |
| 249 | return count; |
| 250 | } |
| 251 | break; |
| 252 | default: |
| 253 | // One by one |
| 254 | for (idx_t i = 0; i < count; ++i) { |
| 255 | const auto idx = vdata.sel->get_index(idx: i); |
| 256 | if (!vvalidity.RowIsValidUnsafe(row_idx: idx)) { |
| 257 | pvalidity.SetInvalidUnsafe(i); |
| 258 | } |
| 259 | } |
| 260 | break; |
| 261 | } |
| 262 | } |
| 263 | return count - pvalidity.CountValid(count); |
| 264 | } else { |
| 265 | return count - VectorOperations::CountNotNull(input&: primary, count); |
| 266 | } |
| 267 | } |
| 268 | |
| 269 | BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, |
| 270 | const SelectionVector &result, const idx_t result_count, |
| 271 | const idx_t left_cols) { |
| 272 | // There should only be one sorted block if they have been sorted |
| 273 | D_ASSERT(state.sorted_blocks.size() == 1); |
| 274 | SBScanState read_state(state.buffer_manager, state); |
| 275 | read_state.sb = state.sorted_blocks[0].get(); |
| 276 | auto &sorted_data = *read_state.sb->payload_data; |
| 277 | |
| 278 | read_state.SetIndices(block_idx_to: block_idx, entry_idx_to: 0); |
| 279 | read_state.PinData(sd&: sorted_data); |
| 280 | const auto data_ptr = read_state.DataPtr(sd&: sorted_data); |
| 281 | data_ptr_t heap_ptr = nullptr; |
| 282 | |
| 283 | // Set up a batch of pointers to scan data from |
| 284 | Vector addresses(LogicalType::POINTER, result_count); |
| 285 | auto data_pointers = FlatVector::GetData<data_ptr_t>(vector&: addresses); |
| 286 | |
| 287 | // Set up the data pointers for the values that are actually referenced |
| 288 | const idx_t &row_width = sorted_data.layout.GetRowWidth(); |
| 289 | |
| 290 | auto prev_idx = result.get_index(idx: 0); |
| 291 | SelectionVector gsel(result_count); |
| 292 | idx_t addr_count = 0; |
| 293 | gsel.set_index(idx: 0, loc: addr_count); |
| 294 | data_pointers[addr_count] = data_ptr + prev_idx * row_width; |
| 295 | for (idx_t i = 1; i < result_count; ++i) { |
| 296 | const auto row_idx = result.get_index(idx: i); |
| 297 | if (row_idx != prev_idx) { |
| 298 | data_pointers[++addr_count] = data_ptr + row_idx * row_width; |
| 299 | prev_idx = row_idx; |
| 300 | } |
| 301 | gsel.set_index(idx: i, loc: addr_count); |
| 302 | } |
| 303 | ++addr_count; |
| 304 | |
| 305 | // Unswizzle the offsets back to pointers (if needed) |
| 306 | if (!sorted_data.layout.AllConstant() && state.external) { |
| 307 | heap_ptr = read_state.payload_heap_handle.Ptr(); |
| 308 | } |
| 309 | |
| 310 | // Deserialize the payload data |
| 311 | auto sel = FlatVector::IncrementalSelectionVector(); |
| 312 | for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { |
| 313 | auto &col = payload.data[left_cols + col_no]; |
| 314 | RowOperations::Gather(rows&: addresses, row_sel: *sel, col, col_sel: *sel, count: addr_count, layout: sorted_data.layout, col_no, build_size: 0, heap_ptr); |
| 315 | col.Slice(sel: gsel, count: result_count); |
| 316 | } |
| 317 | |
| 318 | return std::move(read_state.payload_heap_handle); |
| 319 | } |
| 320 | |
| 321 | idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, |
| 322 | const SelectionVector *sel, idx_t count, SelectionVector *true_sel) { |
| 323 | switch (condition) { |
| 324 | case ExpressionType::COMPARE_NOTEQUAL: |
| 325 | return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel: nullptr); |
| 326 | case ExpressionType::COMPARE_LESSTHAN: |
| 327 | return VectorOperations::LessThan(left, right, sel, count, true_sel, false_sel: nullptr); |
| 328 | case ExpressionType::COMPARE_GREATERTHAN: |
| 329 | return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel: nullptr); |
| 330 | case ExpressionType::COMPARE_LESSTHANOREQUALTO: |
| 331 | return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, false_sel: nullptr); |
| 332 | case ExpressionType::COMPARE_GREATERTHANOREQUALTO: |
| 333 | return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel: nullptr); |
| 334 | case ExpressionType::COMPARE_DISTINCT_FROM: |
| 335 | return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, false_sel: nullptr); |
| 336 | case ExpressionType::COMPARE_NOT_DISTINCT_FROM: |
| 337 | return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, false_sel: nullptr); |
| 338 | case ExpressionType::COMPARE_EQUAL: |
| 339 | return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel: nullptr); |
| 340 | default: |
| 341 | throw InternalException("Unsupported comparison type for PhysicalRangeJoin" ); |
| 342 | } |
| 343 | |
| 344 | return count; |
| 345 | } |
| 346 | |
| 347 | } // namespace duckdb |
| 348 | |