1#include "duckdb/common/operator/comparison_operators.hpp"
2#include "duckdb/common/vector_operations/vector_operations.hpp"
3#include "duckdb/execution/nested_loop_join.hpp"
4
5using namespace duckdb;
6using namespace std;
7
8template <class T, class OP>
9static void mark_join_templated(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) {
10 VectorData left_data, right_data;
11 left.Orrify(lcount, left_data);
12 right.Orrify(rcount, right_data);
13
14 auto ldata = (T *)left_data.data;
15 auto rdata = (T *)right_data.data;
16 for (idx_t i = 0; i < lcount; i++) {
17 if (found_match[i]) {
18 continue;
19 }
20 auto lidx = left_data.sel->get_index(i);
21 if ((*left_data.nullmask)[lidx]) {
22 continue;
23 }
24 for (idx_t j = 0; j < rcount; j++) {
25 auto ridx = right_data.sel->get_index(j);
26 if ((*right_data.nullmask)[ridx]) {
27 continue;
28 }
29 if (OP::Operation(ldata[lidx], rdata[ridx])) {
30 found_match[i] = true;
31 break;
32 }
33 }
34 }
35}
36
37template <class OP>
38static void mark_join_operator(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) {
39 switch (left.type) {
40 case TypeId::BOOL:
41 case TypeId::INT8:
42 return mark_join_templated<int8_t, OP>(left, right, lcount, rcount, found_match);
43 case TypeId::INT16:
44 return mark_join_templated<int16_t, OP>(left, right, lcount, rcount, found_match);
45 case TypeId::INT32:
46 return mark_join_templated<int32_t, OP>(left, right, lcount, rcount, found_match);
47 case TypeId::INT64:
48 return mark_join_templated<int64_t, OP>(left, right, lcount, rcount, found_match);
49 case TypeId::FLOAT:
50 return mark_join_templated<float, OP>(left, right, lcount, rcount, found_match);
51 case TypeId::DOUBLE:
52 return mark_join_templated<double, OP>(left, right, lcount, rcount, found_match);
53 case TypeId::VARCHAR:
54 return mark_join_templated<string_t, OP>(left, right, lcount, rcount, found_match);
55 default:
56 throw NotImplementedException("Unimplemented type for join!");
57 }
58}
59
60static void mark_join(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[],
61 ExpressionType comparison_type) {
62 assert(left.type == right.type);
63 switch (comparison_type) {
64 case ExpressionType::COMPARE_EQUAL:
65 return mark_join_operator<duckdb::Equals>(left, right, lcount, rcount, found_match);
66 case ExpressionType::COMPARE_NOTEQUAL:
67 return mark_join_operator<duckdb::NotEquals>(left, right, lcount, rcount, found_match);
68 case ExpressionType::COMPARE_LESSTHAN:
69 return mark_join_operator<duckdb::LessThan>(left, right, lcount, rcount, found_match);
70 case ExpressionType::COMPARE_GREATERTHAN:
71 return mark_join_operator<duckdb::GreaterThan>(left, right, lcount, rcount, found_match);
72 case ExpressionType::COMPARE_LESSTHANOREQUALTO:
73 return mark_join_operator<duckdb::LessThanEquals>(left, right, lcount, rcount, found_match);
74 case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
75 return mark_join_operator<duckdb::GreaterThanEquals>(left, right, lcount, rcount, found_match);
76 default:
77 throw NotImplementedException("Unimplemented comparison type for join!");
78 }
79}
80
81void NestedLoopJoinMark::Perform(DataChunk &left, ChunkCollection &right, bool found_match[],
82 vector<JoinCondition> &conditions) {
83 // initialize a new temporary selection vector for the left chunk
84 // loop over all chunks in the RHS
85 for (idx_t chunk_idx = 0; chunk_idx < right.chunks.size(); chunk_idx++) {
86 DataChunk &right_chunk = *right.chunks[chunk_idx];
87 for (idx_t i = 0; i < conditions.size(); i++) {
88 mark_join(left.data[i], right_chunk.data[i], left.size(), right_chunk.size(), found_match,
89 conditions[i].comparison);
90 }
91 }
92}
93