1#include "duckdb/common/types/data_chunk.hpp"
2#include "duckdb/function/scalar/nested_functions.hpp"
3#include "duckdb/planner/expression/bound_function_expression.hpp"
4#include "duckdb/planner/expression/bound_parameter_expression.hpp"
5#include "duckdb/planner/expression_binder.hpp"
6
7namespace duckdb {
8
9static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) {
10 D_ASSERT(args.ColumnCount() == 2);
11 auto count = args.size();
12
13 Vector &lhs = args.data[0];
14 Vector &rhs = args.data[1];
15 if (lhs.GetType().id() == LogicalTypeId::SQLNULL) {
16 result.Reference(other&: rhs);
17 return;
18 }
19 if (rhs.GetType().id() == LogicalTypeId::SQLNULL) {
20 result.Reference(other&: lhs);
21 return;
22 }
23
24 UnifiedVectorFormat lhs_data;
25 UnifiedVectorFormat rhs_data;
26 lhs.ToUnifiedFormat(count, data&: lhs_data);
27 rhs.ToUnifiedFormat(count, data&: rhs_data);
28 auto lhs_entries = UnifiedVectorFormat::GetData<list_entry_t>(format: lhs_data);
29 auto rhs_entries = UnifiedVectorFormat::GetData<list_entry_t>(format: rhs_data);
30
31 auto lhs_list_size = ListVector::GetListSize(vector: lhs);
32 auto rhs_list_size = ListVector::GetListSize(vector: rhs);
33 auto &lhs_child = ListVector::GetEntry(vector&: lhs);
34 auto &rhs_child = ListVector::GetEntry(vector&: rhs);
35 UnifiedVectorFormat lhs_child_data;
36 UnifiedVectorFormat rhs_child_data;
37 lhs_child.ToUnifiedFormat(count: lhs_list_size, data&: lhs_child_data);
38 rhs_child.ToUnifiedFormat(count: rhs_list_size, data&: rhs_child_data);
39
40 result.SetVectorType(VectorType::FLAT_VECTOR);
41 auto result_entries = FlatVector::GetData<list_entry_t>(vector&: result);
42 auto &result_validity = FlatVector::Validity(vector&: result);
43
44 idx_t offset = 0;
45 for (idx_t i = 0; i < count; i++) {
46 auto lhs_list_index = lhs_data.sel->get_index(idx: i);
47 auto rhs_list_index = rhs_data.sel->get_index(idx: i);
48 if (!lhs_data.validity.RowIsValid(row_idx: lhs_list_index) && !rhs_data.validity.RowIsValid(row_idx: rhs_list_index)) {
49 result_validity.SetInvalid(i);
50 continue;
51 }
52 result_entries[i].offset = offset;
53 result_entries[i].length = 0;
54 if (lhs_data.validity.RowIsValid(row_idx: lhs_list_index)) {
55 const auto &lhs_entry = lhs_entries[lhs_list_index];
56 result_entries[i].length += lhs_entry.length;
57 ListVector::Append(target&: result, source: lhs_child, sel: *lhs_child_data.sel, source_size: lhs_entry.offset + lhs_entry.length,
58 source_offset: lhs_entry.offset);
59 }
60 if (rhs_data.validity.RowIsValid(row_idx: rhs_list_index)) {
61 const auto &rhs_entry = rhs_entries[rhs_list_index];
62 result_entries[i].length += rhs_entry.length;
63 ListVector::Append(target&: result, source: rhs_child, sel: *rhs_child_data.sel, source_size: rhs_entry.offset + rhs_entry.length,
64 source_offset: rhs_entry.offset);
65 }
66 offset += result_entries[i].length;
67 }
68 D_ASSERT(ListVector::GetListSize(result) == offset);
69
70 if (lhs.GetVectorType() == VectorType::CONSTANT_VECTOR && rhs.GetVectorType() == VectorType::CONSTANT_VECTOR) {
71 result.SetVectorType(VectorType::CONSTANT_VECTOR);
72 }
73}
74
75static unique_ptr<FunctionData> ListConcatBind(ClientContext &context, ScalarFunction &bound_function,
76 vector<unique_ptr<Expression>> &arguments) {
77 D_ASSERT(bound_function.arguments.size() == 2);
78
79 auto &lhs = arguments[0]->return_type;
80 auto &rhs = arguments[1]->return_type;
81 if (lhs.id() == LogicalTypeId::UNKNOWN || rhs.id() == LogicalTypeId::UNKNOWN) {
82 throw ParameterNotResolvedException();
83 } else if (lhs.id() == LogicalTypeId::SQLNULL || rhs.id() == LogicalTypeId::SQLNULL) {
84 // we mimic postgres behaviour: list_concat(NULL, my_list) = my_list
85 auto return_type = rhs.id() == LogicalTypeId::SQLNULL ? lhs : rhs;
86 bound_function.arguments[0] = return_type;
87 bound_function.arguments[1] = return_type;
88 bound_function.return_type = return_type;
89 } else {
90 D_ASSERT(lhs.id() == LogicalTypeId::LIST);
91 D_ASSERT(rhs.id() == LogicalTypeId::LIST);
92
93 // Resolve list type
94 LogicalType child_type = LogicalType::SQLNULL;
95 for (const auto &argument : arguments) {
96 child_type = LogicalType::MaxLogicalType(left: child_type, right: ListType::GetChildType(type: argument->return_type));
97 }
98 auto list_type = LogicalType::LIST(child: child_type);
99
100 bound_function.arguments[0] = list_type;
101 bound_function.arguments[1] = list_type;
102 bound_function.return_type = list_type;
103 }
104 return make_uniq<VariableReturnBindData>(args&: bound_function.return_type);
105}
106
107static unique_ptr<BaseStatistics> ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) {
108 auto &child_stats = input.child_stats;
109 D_ASSERT(child_stats.size() == 2);
110
111 auto &left_stats = child_stats[0];
112 auto &right_stats = child_stats[1];
113
114 auto stats = left_stats.ToUnique();
115 stats->Merge(other: right_stats);
116
117 return stats;
118}
119
120ScalarFunction ListConcatFun::GetFunction() {
121 // the arguments and return types are actually set in the binder function
122 auto fun = ScalarFunction({LogicalType::LIST(child: LogicalType::ANY), LogicalType::LIST(child: LogicalType::ANY)},
123 LogicalType::LIST(child: LogicalType::ANY), ListConcatFunction, ListConcatBind, nullptr,
124 ListConcatStats);
125 fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
126 return fun;
127}
128
129void ListConcatFun::RegisterFunction(BuiltinFunctions &set) {
130 set.AddFunction(names: {"list_concat", "list_cat", "array_concat", "array_cat"}, function: GetFunction());
131}
132
133} // namespace duckdb
134