1#include "duckdb/function/aggregate/distributive_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4#include "duckdb/planner/expression.hpp"
5
6namespace duckdb {
7
//! Aggregate state used by first/last/any_value for fixed-size value types T.
template <class T>
struct FirstState {
	//! Recorded value; only meaningful while is_set is true and is_null is false
	T value;
	//! True once a row (a value, or a NULL that is not skipped) has been recorded
	bool is_set;
	//! True when the recorded row was NULL
	bool is_null;
};
14
//! Pieces shared by the first/last/any_value operators whose states expose is_set / is_null flags.
struct FirstFunctionBase {
	//! Put a freshly allocated state into the "nothing recorded yet" condition.
	template <class STATE>
	static void Initialize(STATE &state) {
		state.is_null = false;
		state.is_set = false;
	}

	//! NULL rows are not filtered out up front; Operation decides (based on SKIP_NULLS) whether to record them.
	static bool IgnoreNull() {
		return false;
	}
};
26
27template <bool LAST, bool SKIP_NULLS>
28struct FirstFunction : public FirstFunctionBase {
29 template <class INPUT_TYPE, class STATE, class OP>
30 static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
31 if (LAST || !state.is_set) {
32 if (!unary_input.RowIsValid()) {
33 if (!SKIP_NULLS) {
34 state.is_set = true;
35 }
36 state.is_null = true;
37 } else {
38 state.is_set = true;
39 state.is_null = false;
40 state.value = input;
41 }
42 }
43 }
44
45 template <class INPUT_TYPE, class STATE, class OP>
46 static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
47 idx_t count) {
48 Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
49 }
50
51 template <class STATE, class OP>
52 static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
53 if (!target.is_set) {
54 target = source;
55 }
56 }
57
58 template <class T, class STATE>
59 static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
60 if (!state.is_set || state.is_null) {
61 finalize_data.ReturnNull();
62 } else {
63 target = state.value;
64 }
65 }
66};
67
68template <bool LAST, bool SKIP_NULLS>
69struct FirstFunctionString : public FirstFunctionBase {
70 template <class STATE>
71 static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) {
72 if (LAST && state.is_set) {
73 Destroy(state, input_data);
74 }
75 if (is_null) {
76 if (!SKIP_NULLS) {
77 state.is_set = true;
78 state.is_null = true;
79 }
80 } else {
81 state.is_set = true;
82 state.is_null = false;
83 if (value.IsInlined()) {
84 state.value = value;
85 } else {
86 // non-inlined string, need to allocate space for it
87 auto len = value.GetSize();
88 auto ptr = new char[len];
89 memcpy(dest: ptr, src: value.GetData(), n: len);
90
91 state.value = string_t(ptr, len);
92 }
93 }
94 }
95
96 template <class INPUT_TYPE, class STATE, class OP>
97 static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
98 if (LAST || !state.is_set) {
99 SetValue(state, unary_input.input, input, !unary_input.RowIsValid());
100 }
101 }
102
103 template <class INPUT_TYPE, class STATE, class OP>
104 static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
105 idx_t count) {
106 Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
107 }
108
109 template <class STATE, class OP>
110 static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) {
111 if (source.is_set && (LAST || !target.is_set)) {
112 SetValue(target, input_data, source.value, source.is_null);
113 }
114 }
115
116 template <class T, class STATE>
117 static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
118 if (!state.is_set || state.is_null) {
119 finalize_data.ReturnNull();
120 } else {
121 target = StringVector::AddStringOrBlob(finalize_data.result, state.value);
122 }
123 }
124
125 template <class STATE>
126 static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
127 if (state.is_set && !state.is_null && !state.value.IsInlined()) {
128 delete[] state.value.GetData();
129 }
130 }
131};
132
133struct FirstStateVector {
134 Vector *value;
135};
136
137template <bool LAST, bool SKIP_NULLS>
138struct FirstVectorFunction {
139 template <class STATE>
140 static void Initialize(STATE &state) {
141 state.value = nullptr;
142 }
143
144 template <class STATE>
145 static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
146 if (state.value) {
147 delete state.value;
148 }
149 }
150 static bool IgnoreNull() {
151 return SKIP_NULLS;
152 }
153
154 template <class STATE>
155 static void SetValue(STATE &state, Vector &input, const idx_t idx) {
156 if (!state.value) {
157 state.value = new Vector(input.GetType());
158 state.value->SetVectorType(VectorType::CONSTANT_VECTOR);
159 }
160 sel_t selv = idx;
161 SelectionVector sel(&selv);
162 VectorOperations::Copy(input, *state.value, sel, 1, 0, 0);
163 }
164
165 static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) {
166 auto &input = inputs[0];
167 UnifiedVectorFormat idata;
168 input.ToUnifiedFormat(count, data&: idata);
169
170 UnifiedVectorFormat sdata;
171 state_vector.ToUnifiedFormat(count, data&: sdata);
172
173 auto states = UnifiedVectorFormat::GetData<FirstStateVector *>(format: sdata);
174 for (idx_t i = 0; i < count; i++) {
175 const auto idx = idata.sel->get_index(idx: i);
176 if (SKIP_NULLS && !idata.validity.RowIsValid(row_idx: idx)) {
177 continue;
178 }
179 auto &state = *states[sdata.sel->get_index(idx: i)];
180 if (LAST || !state.value) {
181 SetValue(state, input, i);
182 }
183 }
184 }
185
186 template <class STATE, class OP>
187 static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
188 if (source.value && (LAST || !target.value)) {
189 SetValue(target, *source.value, 0);
190 }
191 }
192
193 template <class STATE>
194 static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) {
195 if (!state.value) {
196 finalize_data.ReturnNull();
197 } else {
198 VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx);
199 }
200 }
201
202 static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
203 vector<unique_ptr<Expression>> &arguments) {
204 function.arguments[0] = arguments[0]->return_type;
205 function.return_type = arguments[0]->return_type;
206 return nullptr;
207 }
208};
209
210template <class T, bool LAST, bool SKIP_NULLS>
211static AggregateFunction GetFirstAggregateTemplated(LogicalType type) {
212 return AggregateFunction::UnaryAggregate<FirstState<T>, T, T, FirstFunction<LAST, SKIP_NULLS>>(type, type);
213}
214
215template <bool LAST, bool SKIP_NULLS>
216static AggregateFunction GetFirstFunction(const LogicalType &type);
217
218template <bool LAST, bool SKIP_NULLS>
219AggregateFunction GetDecimalFirstFunction(const LogicalType &type) {
220 D_ASSERT(type.id() == LogicalTypeId::DECIMAL);
221 switch (type.InternalType()) {
222 case PhysicalType::INT16:
223 return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::SMALLINT);
224 case PhysicalType::INT32:
225 return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::INTEGER);
226 case PhysicalType::INT64:
227 return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::BIGINT);
228 default:
229 return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::HUGEINT);
230 }
231}
232
233template <bool LAST, bool SKIP_NULLS>
234static AggregateFunction GetFirstFunction(const LogicalType &type) {
235 switch (type.id()) {
236 case LogicalTypeId::BOOLEAN:
237 return GetFirstAggregateTemplated<int8_t, LAST, SKIP_NULLS>(type);
238 case LogicalTypeId::TINYINT:
239 return GetFirstAggregateTemplated<int8_t, LAST, SKIP_NULLS>(type);
240 case LogicalTypeId::SMALLINT:
241 return GetFirstAggregateTemplated<int16_t, LAST, SKIP_NULLS>(type);
242 case LogicalTypeId::INTEGER:
243 case LogicalTypeId::DATE:
244 return GetFirstAggregateTemplated<int32_t, LAST, SKIP_NULLS>(type);
245 case LogicalTypeId::BIGINT:
246 case LogicalTypeId::TIME:
247 case LogicalTypeId::TIMESTAMP:
248 case LogicalTypeId::TIME_TZ:
249 case LogicalTypeId::TIMESTAMP_TZ:
250 return GetFirstAggregateTemplated<int64_t, LAST, SKIP_NULLS>(type);
251 case LogicalTypeId::UTINYINT:
252 return GetFirstAggregateTemplated<uint8_t, LAST, SKIP_NULLS>(type);
253 case LogicalTypeId::USMALLINT:
254 return GetFirstAggregateTemplated<uint16_t, LAST, SKIP_NULLS>(type);
255 case LogicalTypeId::UINTEGER:
256 return GetFirstAggregateTemplated<uint32_t, LAST, SKIP_NULLS>(type);
257 case LogicalTypeId::UBIGINT:
258 return GetFirstAggregateTemplated<uint64_t, LAST, SKIP_NULLS>(type);
259 case LogicalTypeId::HUGEINT:
260 return GetFirstAggregateTemplated<hugeint_t, LAST, SKIP_NULLS>(type);
261 case LogicalTypeId::FLOAT:
262 return GetFirstAggregateTemplated<float, LAST, SKIP_NULLS>(type);
263 case LogicalTypeId::DOUBLE:
264 return GetFirstAggregateTemplated<double, LAST, SKIP_NULLS>(type);
265 case LogicalTypeId::INTERVAL:
266 return GetFirstAggregateTemplated<interval_t, LAST, SKIP_NULLS>(type);
267 case LogicalTypeId::VARCHAR:
268 case LogicalTypeId::BLOB:
269 return AggregateFunction::UnaryAggregateDestructor<FirstState<string_t>, string_t, string_t,
270 FirstFunctionString<LAST, SKIP_NULLS>>(type, type);
271 case LogicalTypeId::DECIMAL: {
272 type.Verify();
273 AggregateFunction function = GetDecimalFirstFunction<LAST, SKIP_NULLS>(type);
274 function.arguments[0] = type;
275 function.return_type = type;
276 // TODO set_key here?
277 return function;
278 }
279 default: {
280 using OP = FirstVectorFunction<LAST, SKIP_NULLS>;
281 return AggregateFunction({type}, type, AggregateFunction::StateSize<FirstStateVector>,
282 AggregateFunction::StateInitialize<FirstStateVector, OP>, OP::Update,
283 AggregateFunction::StateCombine<FirstStateVector, OP>,
284 AggregateFunction::StateVoidFinalize<FirstStateVector, OP>, nullptr, OP::Bind,
285 AggregateFunction::StateDestroy<FirstStateVector, OP>, nullptr, nullptr);
286 }
287 }
288}
289
290AggregateFunction FirstFun::GetFunction(const LogicalType &type) {
291 auto fun = GetFirstFunction<false, false>(type);
292 fun.name = "first";
293 return fun;
294}
295
296template <bool LAST, bool SKIP_NULLS>
297unique_ptr<FunctionData> BindDecimalFirst(ClientContext &context, AggregateFunction &function,
298 vector<unique_ptr<Expression>> &arguments) {
299 auto decimal_type = arguments[0]->return_type;
300 auto name = std::move(function.name);
301 function = GetFirstFunction<LAST, SKIP_NULLS>(decimal_type);
302 function.name = std::move(name);
303 function.return_type = decimal_type;
304 return nullptr;
305}
306
307template <bool LAST, bool SKIP_NULLS>
308static AggregateFunction GetFirstOperator(const LogicalType &type) {
309 if (type.id() == LogicalTypeId::DECIMAL) {
310 throw InternalException("FIXME: this shouldn't happen...");
311 }
312 return GetFirstFunction<LAST, SKIP_NULLS>(type);
313}
314
315template <bool LAST, bool SKIP_NULLS>
316unique_ptr<FunctionData> BindFirst(ClientContext &context, AggregateFunction &function,
317 vector<unique_ptr<Expression>> &arguments) {
318 auto input_type = arguments[0]->return_type;
319 auto name = std::move(function.name);
320 function = GetFirstOperator<LAST, SKIP_NULLS>(input_type);
321 function.name = std::move(name);
322 if (function.bind) {
323 return function.bind(context, function, arguments);
324 } else {
325 return nullptr;
326 }
327}
328
329template <bool LAST, bool SKIP_NULLS>
330static void AddFirstOperator(AggregateFunctionSet &set) {
331 set.AddFunction(function: AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
332 nullptr, nullptr, nullptr, BindDecimalFirst<LAST, SKIP_NULLS>));
333 set.AddFunction(function: AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr,
334 nullptr, BindFirst<LAST, SKIP_NULLS>));
335}
336
337void FirstFun::RegisterFunction(BuiltinFunctions &set) {
338 AggregateFunctionSet first("first");
339 AggregateFunctionSet last("last");
340 AggregateFunctionSet any_value("any_value");
341
342 AddFirstOperator<false, false>(set&: first);
343 AddFirstOperator<true, false>(set&: last);
344 AddFirstOperator<false, true>(set&: any_value);
345
346 set.AddFunction(set: first);
347 first.name = "arbitrary";
348 set.AddFunction(set: first);
349
350 set.AddFunction(set: last);
351
352 set.AddFunction(set: any_value);
353}
354
355} // namespace duckdb
356