1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
2 | #include "duckdb/common/exception.hpp" |
3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
4 | #include "duckdb/planner/expression.hpp" |
5 | |
6 | namespace duckdb { |
7 | |
//! Aggregate state used by the first/last/any_value operators that keep the chosen value inline
//! (the fixed-width types and string_t).
template <class T>
struct FirstState {
	//! The recorded value; only read when is_set is true and is_null is false.
	T value;
	//! Whether a row has been recorded (a NULL row counts unless NULLs are skipped).
	bool is_set;
	//! Whether the recorded row was NULL.
	bool is_null;
};
14 | |
//! Behaviour shared by the first/last/any_value operators that store their result in a FirstState.
struct FirstFunctionBase {
	//! Puts a state into the "nothing recorded yet" condition.
	template <class STATE>
	static void Initialize(STATE &state) {
		state.is_null = false;
		state.is_set = false;
	}

	//! Returning false presumably tells the aggregate framework not to filter NULL rows;
	//! the operators inspect row validity themselves.
	static bool IgnoreNull() {
		return false;
	}
};
26 | |
27 | template <bool LAST, bool SKIP_NULLS> |
28 | struct FirstFunction : public FirstFunctionBase { |
29 | template <class INPUT_TYPE, class STATE, class OP> |
30 | static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { |
31 | if (LAST || !state.is_set) { |
32 | if (!unary_input.RowIsValid()) { |
33 | if (!SKIP_NULLS) { |
34 | state.is_set = true; |
35 | } |
36 | state.is_null = true; |
37 | } else { |
38 | state.is_set = true; |
39 | state.is_null = false; |
40 | state.value = input; |
41 | } |
42 | } |
43 | } |
44 | |
45 | template <class INPUT_TYPE, class STATE, class OP> |
46 | static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, |
47 | idx_t count) { |
48 | Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input); |
49 | } |
50 | |
51 | template <class STATE, class OP> |
52 | static void Combine(const STATE &source, STATE &target, AggregateInputData &) { |
53 | if (!target.is_set) { |
54 | target = source; |
55 | } |
56 | } |
57 | |
58 | template <class T, class STATE> |
59 | static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { |
60 | if (!state.is_set || state.is_null) { |
61 | finalize_data.ReturnNull(); |
62 | } else { |
63 | target = state.value; |
64 | } |
65 | } |
66 | }; |
67 | |
68 | template <bool LAST, bool SKIP_NULLS> |
69 | struct FirstFunctionString : public FirstFunctionBase { |
70 | template <class STATE> |
71 | static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) { |
72 | if (LAST && state.is_set) { |
73 | Destroy(state, input_data); |
74 | } |
75 | if (is_null) { |
76 | if (!SKIP_NULLS) { |
77 | state.is_set = true; |
78 | state.is_null = true; |
79 | } |
80 | } else { |
81 | state.is_set = true; |
82 | state.is_null = false; |
83 | if (value.IsInlined()) { |
84 | state.value = value; |
85 | } else { |
86 | // non-inlined string, need to allocate space for it |
87 | auto len = value.GetSize(); |
88 | auto ptr = new char[len]; |
89 | memcpy(dest: ptr, src: value.GetData(), n: len); |
90 | |
91 | state.value = string_t(ptr, len); |
92 | } |
93 | } |
94 | } |
95 | |
96 | template <class INPUT_TYPE, class STATE, class OP> |
97 | static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { |
98 | if (LAST || !state.is_set) { |
99 | SetValue(state, unary_input.input, input, !unary_input.RowIsValid()); |
100 | } |
101 | } |
102 | |
103 | template <class INPUT_TYPE, class STATE, class OP> |
104 | static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, |
105 | idx_t count) { |
106 | Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input); |
107 | } |
108 | |
109 | template <class STATE, class OP> |
110 | static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { |
111 | if (source.is_set && (LAST || !target.is_set)) { |
112 | SetValue(target, input_data, source.value, source.is_null); |
113 | } |
114 | } |
115 | |
116 | template <class T, class STATE> |
117 | static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { |
118 | if (!state.is_set || state.is_null) { |
119 | finalize_data.ReturnNull(); |
120 | } else { |
121 | target = StringVector::AddStringOrBlob(finalize_data.result, state.value); |
122 | } |
123 | } |
124 | |
125 | template <class STATE> |
126 | static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { |
127 | if (state.is_set && !state.is_null && !state.value.IsInlined()) { |
128 | delete[] state.value.GetData(); |
129 | } |
130 | } |
131 | }; |
132 | |
133 | struct FirstStateVector { |
134 | Vector *value; |
135 | }; |
136 | |
137 | template <bool LAST, bool SKIP_NULLS> |
138 | struct FirstVectorFunction { |
139 | template <class STATE> |
140 | static void Initialize(STATE &state) { |
141 | state.value = nullptr; |
142 | } |
143 | |
144 | template <class STATE> |
145 | static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { |
146 | if (state.value) { |
147 | delete state.value; |
148 | } |
149 | } |
150 | static bool IgnoreNull() { |
151 | return SKIP_NULLS; |
152 | } |
153 | |
154 | template <class STATE> |
155 | static void SetValue(STATE &state, Vector &input, const idx_t idx) { |
156 | if (!state.value) { |
157 | state.value = new Vector(input.GetType()); |
158 | state.value->SetVectorType(VectorType::CONSTANT_VECTOR); |
159 | } |
160 | sel_t selv = idx; |
161 | SelectionVector sel(&selv); |
162 | VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); |
163 | } |
164 | |
165 | static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { |
166 | auto &input = inputs[0]; |
167 | UnifiedVectorFormat idata; |
168 | input.ToUnifiedFormat(count, data&: idata); |
169 | |
170 | UnifiedVectorFormat sdata; |
171 | state_vector.ToUnifiedFormat(count, data&: sdata); |
172 | |
173 | auto states = UnifiedVectorFormat::GetData<FirstStateVector *>(format: sdata); |
174 | for (idx_t i = 0; i < count; i++) { |
175 | const auto idx = idata.sel->get_index(idx: i); |
176 | if (SKIP_NULLS && !idata.validity.RowIsValid(row_idx: idx)) { |
177 | continue; |
178 | } |
179 | auto &state = *states[sdata.sel->get_index(idx: i)]; |
180 | if (LAST || !state.value) { |
181 | SetValue(state, input, i); |
182 | } |
183 | } |
184 | } |
185 | |
186 | template <class STATE, class OP> |
187 | static void Combine(const STATE &source, STATE &target, AggregateInputData &) { |
188 | if (source.value && (LAST || !target.value)) { |
189 | SetValue(target, *source.value, 0); |
190 | } |
191 | } |
192 | |
193 | template <class STATE> |
194 | static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { |
195 | if (!state.value) { |
196 | finalize_data.ReturnNull(); |
197 | } else { |
198 | VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); |
199 | } |
200 | } |
201 | |
202 | static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function, |
203 | vector<unique_ptr<Expression>> &arguments) { |
204 | function.arguments[0] = arguments[0]->return_type; |
205 | function.return_type = arguments[0]->return_type; |
206 | return nullptr; |
207 | } |
208 | }; |
209 | |
210 | template <class T, bool LAST, bool SKIP_NULLS> |
211 | static AggregateFunction GetFirstAggregateTemplated(LogicalType type) { |
212 | return AggregateFunction::UnaryAggregate<FirstState<T>, T, T, FirstFunction<LAST, SKIP_NULLS>>(type, type); |
213 | } |
214 | |
215 | template <bool LAST, bool SKIP_NULLS> |
216 | static AggregateFunction GetFirstFunction(const LogicalType &type); |
217 | |
218 | template <bool LAST, bool SKIP_NULLS> |
219 | AggregateFunction GetDecimalFirstFunction(const LogicalType &type) { |
220 | D_ASSERT(type.id() == LogicalTypeId::DECIMAL); |
221 | switch (type.InternalType()) { |
222 | case PhysicalType::INT16: |
223 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::SMALLINT); |
224 | case PhysicalType::INT32: |
225 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::INTEGER); |
226 | case PhysicalType::INT64: |
227 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::BIGINT); |
228 | default: |
229 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::HUGEINT); |
230 | } |
231 | } |
232 | |
233 | template <bool LAST, bool SKIP_NULLS> |
234 | static AggregateFunction GetFirstFunction(const LogicalType &type) { |
235 | switch (type.id()) { |
236 | case LogicalTypeId::BOOLEAN: |
237 | return GetFirstAggregateTemplated<int8_t, LAST, SKIP_NULLS>(type); |
238 | case LogicalTypeId::TINYINT: |
239 | return GetFirstAggregateTemplated<int8_t, LAST, SKIP_NULLS>(type); |
240 | case LogicalTypeId::SMALLINT: |
241 | return GetFirstAggregateTemplated<int16_t, LAST, SKIP_NULLS>(type); |
242 | case LogicalTypeId::INTEGER: |
243 | case LogicalTypeId::DATE: |
244 | return GetFirstAggregateTemplated<int32_t, LAST, SKIP_NULLS>(type); |
245 | case LogicalTypeId::BIGINT: |
246 | case LogicalTypeId::TIME: |
247 | case LogicalTypeId::TIMESTAMP: |
248 | case LogicalTypeId::TIME_TZ: |
249 | case LogicalTypeId::TIMESTAMP_TZ: |
250 | return GetFirstAggregateTemplated<int64_t, LAST, SKIP_NULLS>(type); |
251 | case LogicalTypeId::UTINYINT: |
252 | return GetFirstAggregateTemplated<uint8_t, LAST, SKIP_NULLS>(type); |
253 | case LogicalTypeId::USMALLINT: |
254 | return GetFirstAggregateTemplated<uint16_t, LAST, SKIP_NULLS>(type); |
255 | case LogicalTypeId::UINTEGER: |
256 | return GetFirstAggregateTemplated<uint32_t, LAST, SKIP_NULLS>(type); |
257 | case LogicalTypeId::UBIGINT: |
258 | return GetFirstAggregateTemplated<uint64_t, LAST, SKIP_NULLS>(type); |
259 | case LogicalTypeId::HUGEINT: |
260 | return GetFirstAggregateTemplated<hugeint_t, LAST, SKIP_NULLS>(type); |
261 | case LogicalTypeId::FLOAT: |
262 | return GetFirstAggregateTemplated<float, LAST, SKIP_NULLS>(type); |
263 | case LogicalTypeId::DOUBLE: |
264 | return GetFirstAggregateTemplated<double, LAST, SKIP_NULLS>(type); |
265 | case LogicalTypeId::INTERVAL: |
266 | return GetFirstAggregateTemplated<interval_t, LAST, SKIP_NULLS>(type); |
267 | case LogicalTypeId::VARCHAR: |
268 | case LogicalTypeId::BLOB: |
269 | return AggregateFunction::UnaryAggregateDestructor<FirstState<string_t>, string_t, string_t, |
270 | FirstFunctionString<LAST, SKIP_NULLS>>(type, type); |
271 | case LogicalTypeId::DECIMAL: { |
272 | type.Verify(); |
273 | AggregateFunction function = GetDecimalFirstFunction<LAST, SKIP_NULLS>(type); |
274 | function.arguments[0] = type; |
275 | function.return_type = type; |
276 | // TODO set_key here? |
277 | return function; |
278 | } |
279 | default: { |
280 | using OP = FirstVectorFunction<LAST, SKIP_NULLS>; |
281 | return AggregateFunction({type}, type, AggregateFunction::StateSize<FirstStateVector>, |
282 | AggregateFunction::StateInitialize<FirstStateVector, OP>, OP::Update, |
283 | AggregateFunction::StateCombine<FirstStateVector, OP>, |
284 | AggregateFunction::StateVoidFinalize<FirstStateVector, OP>, nullptr, OP::Bind, |
285 | AggregateFunction::StateDestroy<FirstStateVector, OP>, nullptr, nullptr); |
286 | } |
287 | } |
288 | } |
289 | |
290 | AggregateFunction FirstFun::GetFunction(const LogicalType &type) { |
291 | auto fun = GetFirstFunction<false, false>(type); |
292 | fun.name = "first" ; |
293 | return fun; |
294 | } |
295 | |
296 | template <bool LAST, bool SKIP_NULLS> |
297 | unique_ptr<FunctionData> BindDecimalFirst(ClientContext &context, AggregateFunction &function, |
298 | vector<unique_ptr<Expression>> &arguments) { |
299 | auto decimal_type = arguments[0]->return_type; |
300 | auto name = std::move(function.name); |
301 | function = GetFirstFunction<LAST, SKIP_NULLS>(decimal_type); |
302 | function.name = std::move(name); |
303 | function.return_type = decimal_type; |
304 | return nullptr; |
305 | } |
306 | |
307 | template <bool LAST, bool SKIP_NULLS> |
308 | static AggregateFunction GetFirstOperator(const LogicalType &type) { |
309 | if (type.id() == LogicalTypeId::DECIMAL) { |
310 | throw InternalException("FIXME: this shouldn't happen..." ); |
311 | } |
312 | return GetFirstFunction<LAST, SKIP_NULLS>(type); |
313 | } |
314 | |
315 | template <bool LAST, bool SKIP_NULLS> |
316 | unique_ptr<FunctionData> BindFirst(ClientContext &context, AggregateFunction &function, |
317 | vector<unique_ptr<Expression>> &arguments) { |
318 | auto input_type = arguments[0]->return_type; |
319 | auto name = std::move(function.name); |
320 | function = GetFirstOperator<LAST, SKIP_NULLS>(input_type); |
321 | function.name = std::move(name); |
322 | if (function.bind) { |
323 | return function.bind(context, function, arguments); |
324 | } else { |
325 | return nullptr; |
326 | } |
327 | } |
328 | |
329 | template <bool LAST, bool SKIP_NULLS> |
330 | static void AddFirstOperator(AggregateFunctionSet &set) { |
331 | set.AddFunction(function: AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, |
332 | nullptr, nullptr, nullptr, BindDecimalFirst<LAST, SKIP_NULLS>)); |
333 | set.AddFunction(function: AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, |
334 | nullptr, BindFirst<LAST, SKIP_NULLS>)); |
335 | } |
336 | |
337 | void FirstFun::RegisterFunction(BuiltinFunctions &set) { |
338 | AggregateFunctionSet first("first" ); |
339 | AggregateFunctionSet last("last" ); |
340 | AggregateFunctionSet any_value("any_value" ); |
341 | |
342 | AddFirstOperator<false, false>(set&: first); |
343 | AddFirstOperator<true, false>(set&: last); |
344 | AddFirstOperator<false, true>(set&: any_value); |
345 | |
346 | set.AddFunction(set: first); |
347 | first.name = "arbitrary" ; |
348 | set.AddFunction(set: first); |
349 | |
350 | set.AddFunction(set: last); |
351 | |
352 | set.AddFunction(set: any_value); |
353 | } |
354 | |
355 | } // namespace duckdb |
356 | |