1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
2 | #include "duckdb/common/exception.hpp" |
3 | #include "duckdb/common/types/null_value.hpp" |
4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
5 | |
6 | using namespace std; |
7 | |
8 | namespace duckdb { |
9 | |
10 | template <class T> struct FirstState { |
11 | bool is_set; |
12 | T value; |
13 | }; |
14 | |
15 | struct FirstFunctionBase { |
16 | template <class STATE> static void Initialize(STATE *state) { |
17 | state->is_set = false; |
18 | } |
19 | |
20 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
21 | if (!target->is_set) { |
22 | *target = source; |
23 | } |
24 | } |
25 | |
26 | static bool IgnoreNull() { |
27 | return false; |
28 | } |
29 | }; |
30 | |
31 | struct FirstFunction : public FirstFunctionBase { |
32 | template <class INPUT_TYPE, class STATE, class OP> |
33 | static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) { |
34 | if (!state->is_set) { |
35 | state->is_set = true; |
36 | if (nullmask[idx]) { |
37 | state->value = NullValue<INPUT_TYPE>(); |
38 | } else { |
39 | state->value = input[idx]; |
40 | } |
41 | } |
42 | } |
43 | |
44 | template <class INPUT_TYPE, class STATE, class OP> |
45 | static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) { |
46 | Operation<INPUT_TYPE, STATE, OP>(state, input, nullmask, 0); |
47 | } |
48 | |
49 | template <class T, class STATE> |
50 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
51 | if (!state->is_set || IsNullValue<T>(state->value)) { |
52 | nullmask[idx] = true; |
53 | } else { |
54 | target[idx] = state->value; |
55 | } |
56 | } |
57 | }; |
58 | |
59 | struct FirstFunctionString : public FirstFunctionBase { |
60 | template <class INPUT_TYPE, class STATE, class OP> |
61 | static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) { |
62 | if (!state->is_set) { |
63 | state->is_set = true; |
64 | if (nullmask[idx]) { |
65 | state->value = NullValue<INPUT_TYPE>(); |
66 | } else { |
67 | if (input[idx].IsInlined()) { |
68 | state->value = input[idx]; |
69 | } else { |
70 | // non-inlined string, need to allocate space for it |
71 | auto len = input[idx].GetSize(); |
72 | auto ptr = new char[len + 1]; |
73 | memcpy(ptr, input[idx].GetData(), len + 1); |
74 | |
75 | state->value = string_t(ptr, len); |
76 | } |
77 | } |
78 | } |
79 | } |
80 | |
81 | template <class INPUT_TYPE, class STATE, class OP> |
82 | static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) { |
83 | Operation<INPUT_TYPE, STATE, OP>(state, input, nullmask, 0); |
84 | } |
85 | |
86 | template <class T, class STATE> |
87 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
88 | if (!state->is_set || IsNullValue<T>(state->value)) { |
89 | nullmask[idx] = true; |
90 | } else { |
91 | target[idx] = StringVector::AddString(result, state->value); |
92 | } |
93 | } |
94 | |
95 | template <class STATE> static void Destroy(STATE *state) { |
96 | if (state->is_set && !state->value.IsInlined()) { |
97 | delete[] state->value.GetData(); |
98 | } |
99 | } |
100 | }; |
101 | |
102 | template <class T> static AggregateFunction GetFirstAggregateTemplated(SQLType type) { |
103 | return AggregateFunction::UnaryAggregate<FirstState<T>, T, T, FirstFunction>(type, type); |
104 | } |
105 | |
106 | AggregateFunction FirstFun::GetFunction(SQLType type) { |
107 | switch (type.id) { |
108 | case SQLTypeId::BOOLEAN: |
109 | return GetFirstAggregateTemplated<int8_t>(type); |
110 | case SQLTypeId::TINYINT: |
111 | return GetFirstAggregateTemplated<int8_t>(type); |
112 | case SQLTypeId::SMALLINT: |
113 | return GetFirstAggregateTemplated<int16_t>(type); |
114 | case SQLTypeId::INTEGER: |
115 | return GetFirstAggregateTemplated<int32_t>(type); |
116 | case SQLTypeId::BIGINT: |
117 | return GetFirstAggregateTemplated<int64_t>(type); |
118 | case SQLTypeId::FLOAT: |
119 | return GetFirstAggregateTemplated<float>(type); |
120 | case SQLTypeId::DOUBLE: |
121 | return GetFirstAggregateTemplated<double>(type); |
122 | case SQLTypeId::DECIMAL: |
123 | return GetFirstAggregateTemplated<double>(type); |
124 | case SQLTypeId::DATE: |
125 | return GetFirstAggregateTemplated<date_t>(type); |
126 | case SQLTypeId::TIMESTAMP: |
127 | return GetFirstAggregateTemplated<timestamp_t>(type); |
128 | case SQLTypeId::VARCHAR: |
129 | case SQLTypeId::BLOB: |
130 | return AggregateFunction::UnaryAggregateDestructor<FirstState<string_t>, string_t, string_t, |
131 | FirstFunctionString>(type, type); |
132 | default: |
133 | throw NotImplementedException("Unimplemented type for FIRST aggregate" ); |
134 | } |
135 | } |
136 | |
137 | void FirstFun::RegisterFunction(BuiltinFunctions &set) { |
138 | AggregateFunctionSet first("first" ); |
139 | for (auto type : SQLType::ALL_TYPES) { |
140 | first.AddFunction(FirstFun::GetFunction(type)); |
141 | } |
142 | set.AddFunction(first); |
143 | } |
144 | |
145 | } // namespace duckdb |
146 | |