1 | #include "duckdb/common/types/data_chunk.hpp" |
2 | #include "duckdb/function/scalar/nested_functions.hpp" |
3 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_parameter_expression.hpp" |
5 | #include "duckdb/planner/expression_binder.hpp" |
6 | |
7 | namespace duckdb { |
8 | |
9 | static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { |
10 | D_ASSERT(args.ColumnCount() == 2); |
11 | auto count = args.size(); |
12 | |
13 | Vector &lhs = args.data[0]; |
14 | Vector &rhs = args.data[1]; |
15 | if (lhs.GetType().id() == LogicalTypeId::SQLNULL) { |
16 | result.Reference(other&: rhs); |
17 | return; |
18 | } |
19 | if (rhs.GetType().id() == LogicalTypeId::SQLNULL) { |
20 | result.Reference(other&: lhs); |
21 | return; |
22 | } |
23 | |
24 | UnifiedVectorFormat lhs_data; |
25 | UnifiedVectorFormat rhs_data; |
26 | lhs.ToUnifiedFormat(count, data&: lhs_data); |
27 | rhs.ToUnifiedFormat(count, data&: rhs_data); |
28 | auto lhs_entries = UnifiedVectorFormat::GetData<list_entry_t>(format: lhs_data); |
29 | auto rhs_entries = UnifiedVectorFormat::GetData<list_entry_t>(format: rhs_data); |
30 | |
31 | auto lhs_list_size = ListVector::GetListSize(vector: lhs); |
32 | auto rhs_list_size = ListVector::GetListSize(vector: rhs); |
33 | auto &lhs_child = ListVector::GetEntry(vector&: lhs); |
34 | auto &rhs_child = ListVector::GetEntry(vector&: rhs); |
35 | UnifiedVectorFormat lhs_child_data; |
36 | UnifiedVectorFormat rhs_child_data; |
37 | lhs_child.ToUnifiedFormat(count: lhs_list_size, data&: lhs_child_data); |
38 | rhs_child.ToUnifiedFormat(count: rhs_list_size, data&: rhs_child_data); |
39 | |
40 | result.SetVectorType(VectorType::FLAT_VECTOR); |
41 | auto result_entries = FlatVector::GetData<list_entry_t>(vector&: result); |
42 | auto &result_validity = FlatVector::Validity(vector&: result); |
43 | |
44 | idx_t offset = 0; |
45 | for (idx_t i = 0; i < count; i++) { |
46 | auto lhs_list_index = lhs_data.sel->get_index(idx: i); |
47 | auto rhs_list_index = rhs_data.sel->get_index(idx: i); |
48 | if (!lhs_data.validity.RowIsValid(row_idx: lhs_list_index) && !rhs_data.validity.RowIsValid(row_idx: rhs_list_index)) { |
49 | result_validity.SetInvalid(i); |
50 | continue; |
51 | } |
52 | result_entries[i].offset = offset; |
53 | result_entries[i].length = 0; |
54 | if (lhs_data.validity.RowIsValid(row_idx: lhs_list_index)) { |
55 | const auto &lhs_entry = lhs_entries[lhs_list_index]; |
56 | result_entries[i].length += lhs_entry.length; |
57 | ListVector::Append(target&: result, source: lhs_child, sel: *lhs_child_data.sel, source_size: lhs_entry.offset + lhs_entry.length, |
58 | source_offset: lhs_entry.offset); |
59 | } |
60 | if (rhs_data.validity.RowIsValid(row_idx: rhs_list_index)) { |
61 | const auto &rhs_entry = rhs_entries[rhs_list_index]; |
62 | result_entries[i].length += rhs_entry.length; |
63 | ListVector::Append(target&: result, source: rhs_child, sel: *rhs_child_data.sel, source_size: rhs_entry.offset + rhs_entry.length, |
64 | source_offset: rhs_entry.offset); |
65 | } |
66 | offset += result_entries[i].length; |
67 | } |
68 | D_ASSERT(ListVector::GetListSize(result) == offset); |
69 | |
70 | if (lhs.GetVectorType() == VectorType::CONSTANT_VECTOR && rhs.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
71 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
72 | } |
73 | } |
74 | |
75 | static unique_ptr<FunctionData> ListConcatBind(ClientContext &context, ScalarFunction &bound_function, |
76 | vector<unique_ptr<Expression>> &arguments) { |
77 | D_ASSERT(bound_function.arguments.size() == 2); |
78 | |
79 | auto &lhs = arguments[0]->return_type; |
80 | auto &rhs = arguments[1]->return_type; |
81 | if (lhs.id() == LogicalTypeId::UNKNOWN || rhs.id() == LogicalTypeId::UNKNOWN) { |
82 | throw ParameterNotResolvedException(); |
83 | } else if (lhs.id() == LogicalTypeId::SQLNULL || rhs.id() == LogicalTypeId::SQLNULL) { |
84 | // we mimic postgres behaviour: list_concat(NULL, my_list) = my_list |
85 | auto return_type = rhs.id() == LogicalTypeId::SQLNULL ? lhs : rhs; |
86 | bound_function.arguments[0] = return_type; |
87 | bound_function.arguments[1] = return_type; |
88 | bound_function.return_type = return_type; |
89 | } else { |
90 | D_ASSERT(lhs.id() == LogicalTypeId::LIST); |
91 | D_ASSERT(rhs.id() == LogicalTypeId::LIST); |
92 | |
93 | // Resolve list type |
94 | LogicalType child_type = LogicalType::SQLNULL; |
95 | for (const auto &argument : arguments) { |
96 | child_type = LogicalType::MaxLogicalType(left: child_type, right: ListType::GetChildType(type: argument->return_type)); |
97 | } |
98 | auto list_type = LogicalType::LIST(child: child_type); |
99 | |
100 | bound_function.arguments[0] = list_type; |
101 | bound_function.arguments[1] = list_type; |
102 | bound_function.return_type = list_type; |
103 | } |
104 | return make_uniq<VariableReturnBindData>(args&: bound_function.return_type); |
105 | } |
106 | |
107 | static unique_ptr<BaseStatistics> ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) { |
108 | auto &child_stats = input.child_stats; |
109 | D_ASSERT(child_stats.size() == 2); |
110 | |
111 | auto &left_stats = child_stats[0]; |
112 | auto &right_stats = child_stats[1]; |
113 | |
114 | auto stats = left_stats.ToUnique(); |
115 | stats->Merge(other: right_stats); |
116 | |
117 | return stats; |
118 | } |
119 | |
120 | ScalarFunction ListConcatFun::GetFunction() { |
121 | // the arguments and return types are actually set in the binder function |
122 | auto fun = ScalarFunction({LogicalType::LIST(child: LogicalType::ANY), LogicalType::LIST(child: LogicalType::ANY)}, |
123 | LogicalType::LIST(child: LogicalType::ANY), ListConcatFunction, ListConcatBind, nullptr, |
124 | ListConcatStats); |
125 | fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
126 | return fun; |
127 | } |
128 | |
129 | void ListConcatFun::RegisterFunction(BuiltinFunctions &set) { |
130 | set.AddFunction(names: {"list_concat" , "list_cat" , "array_concat" , "array_cat" }, function: GetFunction()); |
131 | } |
132 | |
133 | } // namespace duckdb |
134 | |