1#include "duckdb/function/aggregate/distributive_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/types/null_value.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5
6using namespace std;
7
8namespace duckdb {
9
10template <class T> struct FirstState {
11 bool is_set;
12 T value;
13};
14
15struct FirstFunctionBase {
16 template <class STATE> static void Initialize(STATE *state) {
17 state->is_set = false;
18 }
19
20 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
21 if (!target->is_set) {
22 *target = source;
23 }
24 }
25
26 static bool IgnoreNull() {
27 return false;
28 }
29};
30
31struct FirstFunction : public FirstFunctionBase {
32 template <class INPUT_TYPE, class STATE, class OP>
33 static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) {
34 if (!state->is_set) {
35 state->is_set = true;
36 if (nullmask[idx]) {
37 state->value = NullValue<INPUT_TYPE>();
38 } else {
39 state->value = input[idx];
40 }
41 }
42 }
43
44 template <class INPUT_TYPE, class STATE, class OP>
45 static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) {
46 Operation<INPUT_TYPE, STATE, OP>(state, input, nullmask, 0);
47 }
48
49 template <class T, class STATE>
50 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
51 if (!state->is_set || IsNullValue<T>(state->value)) {
52 nullmask[idx] = true;
53 } else {
54 target[idx] = state->value;
55 }
56 }
57};
58
59struct FirstFunctionString : public FirstFunctionBase {
60 template <class INPUT_TYPE, class STATE, class OP>
61 static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) {
62 if (!state->is_set) {
63 state->is_set = true;
64 if (nullmask[idx]) {
65 state->value = NullValue<INPUT_TYPE>();
66 } else {
67 if (input[idx].IsInlined()) {
68 state->value = input[idx];
69 } else {
70 // non-inlined string, need to allocate space for it
71 auto len = input[idx].GetSize();
72 auto ptr = new char[len + 1];
73 memcpy(ptr, input[idx].GetData(), len + 1);
74
75 state->value = string_t(ptr, len);
76 }
77 }
78 }
79 }
80
81 template <class INPUT_TYPE, class STATE, class OP>
82 static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) {
83 Operation<INPUT_TYPE, STATE, OP>(state, input, nullmask, 0);
84 }
85
86 template <class T, class STATE>
87 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
88 if (!state->is_set || IsNullValue<T>(state->value)) {
89 nullmask[idx] = true;
90 } else {
91 target[idx] = StringVector::AddString(result, state->value);
92 }
93 }
94
95 template <class STATE> static void Destroy(STATE *state) {
96 if (state->is_set && !state->value.IsInlined()) {
97 delete[] state->value.GetData();
98 }
99 }
100};
101
102template <class T> static AggregateFunction GetFirstAggregateTemplated(SQLType type) {
103 return AggregateFunction::UnaryAggregate<FirstState<T>, T, T, FirstFunction>(type, type);
104}
105
106AggregateFunction FirstFun::GetFunction(SQLType type) {
107 switch (type.id) {
108 case SQLTypeId::BOOLEAN:
109 return GetFirstAggregateTemplated<int8_t>(type);
110 case SQLTypeId::TINYINT:
111 return GetFirstAggregateTemplated<int8_t>(type);
112 case SQLTypeId::SMALLINT:
113 return GetFirstAggregateTemplated<int16_t>(type);
114 case SQLTypeId::INTEGER:
115 return GetFirstAggregateTemplated<int32_t>(type);
116 case SQLTypeId::BIGINT:
117 return GetFirstAggregateTemplated<int64_t>(type);
118 case SQLTypeId::FLOAT:
119 return GetFirstAggregateTemplated<float>(type);
120 case SQLTypeId::DOUBLE:
121 return GetFirstAggregateTemplated<double>(type);
122 case SQLTypeId::DECIMAL:
123 return GetFirstAggregateTemplated<double>(type);
124 case SQLTypeId::DATE:
125 return GetFirstAggregateTemplated<date_t>(type);
126 case SQLTypeId::TIMESTAMP:
127 return GetFirstAggregateTemplated<timestamp_t>(type);
128 case SQLTypeId::VARCHAR:
129 case SQLTypeId::BLOB:
130 return AggregateFunction::UnaryAggregateDestructor<FirstState<string_t>, string_t, string_t,
131 FirstFunctionString>(type, type);
132 default:
133 throw NotImplementedException("Unimplemented type for FIRST aggregate");
134 }
135}
136
137void FirstFun::RegisterFunction(BuiltinFunctions &set) {
138 AggregateFunctionSet first("first");
139 for (auto type : SQLType::ALL_TYPES) {
140 first.AddFunction(FirstFun::GetFunction(type));
141 }
142 set.AddFunction(first);
143}
144
145} // namespace duckdb
146