1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
2 | #include "duckdb/common/exception.hpp" |
3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
4 | #include "duckdb/common/operator/comparison_operators.hpp" |
5 | #include "duckdb/common/vector_operations/aggregate_executor.hpp" |
6 | #include "duckdb/common/operator/aggregate_operators.hpp" |
7 | #include "duckdb/common/types/null_value.hpp" |
8 | |
9 | using namespace std; |
10 | |
11 | namespace duckdb { |
12 | |
// Running state used by the MIN and MAX aggregates defined in this file.
//   value - the current extreme; only meaningful while isset is true
//   isset - false until a first input has been stored in value
// The struct has no default member initializers: value starts out
// indeterminate and MinMaxBase::Initialize only resets isset.
template <class T>
struct min_max_state_t {
	T value;
	bool isset;
};
17 | |
18 | template <class OP> static AggregateFunction GetUnaryAggregate(SQLType type) { |
19 | switch (type.id) { |
20 | case SQLTypeId::BOOLEAN: |
21 | return AggregateFunction::UnaryAggregate<min_max_state_t<int8_t>, int8_t, int8_t, OP>(type, type); |
22 | case SQLTypeId::TINYINT: |
23 | return AggregateFunction::UnaryAggregate<min_max_state_t<int8_t>, int8_t, int8_t, OP>(type, type); |
24 | case SQLTypeId::SMALLINT: |
25 | return AggregateFunction::UnaryAggregate<min_max_state_t<int16_t>, int16_t, int16_t, OP>(type, type); |
26 | case SQLTypeId::INTEGER: |
27 | return AggregateFunction::UnaryAggregate<min_max_state_t<int32_t>, int32_t, int32_t, OP>(type, type); |
28 | case SQLTypeId::BIGINT: |
29 | return AggregateFunction::UnaryAggregate<min_max_state_t<int64_t>, int64_t, int64_t, OP>(type, type); |
30 | case SQLTypeId::FLOAT: |
31 | return AggregateFunction::UnaryAggregate<min_max_state_t<float>, float, float, OP>(type, type); |
32 | case SQLTypeId::DOUBLE: |
33 | return AggregateFunction::UnaryAggregate<min_max_state_t<double>, double, double, OP>(type, type); |
34 | case SQLTypeId::DECIMAL: |
35 | return AggregateFunction::UnaryAggregate<min_max_state_t<double>, double, double, OP>(type, type); |
36 | case SQLTypeId::DATE: |
37 | return AggregateFunction::UnaryAggregate<min_max_state_t<date_t>, date_t, date_t, OP>(type, type); |
38 | case SQLTypeId::TIMESTAMP: |
39 | return AggregateFunction::UnaryAggregate<min_max_state_t<timestamp_t>, timestamp_t, timestamp_t, OP>(type, |
40 | type); |
41 | default: |
42 | throw NotImplementedException("Unimplemented type for unary aggregate" ); |
43 | } |
44 | } |
45 | |
46 | struct MinMaxBase { |
47 | template <class STATE> static void Initialize(STATE *state) { |
48 | state->isset = false; |
49 | } |
50 | |
51 | template <class INPUT_TYPE, class STATE, class OP> |
52 | static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) { |
53 | assert(!nullmask[0]); |
54 | if (!state->isset) { |
55 | state->isset = true; |
56 | OP::template Assign<INPUT_TYPE, STATE>(state, input[0]); |
57 | } else { |
58 | OP::template Execute<INPUT_TYPE, STATE>(state, input[0]); |
59 | } |
60 | } |
61 | |
62 | template <class INPUT_TYPE, class STATE, class OP> |
63 | static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) { |
64 | if (!state->isset) { |
65 | state->isset = true; |
66 | OP::template Assign<INPUT_TYPE, STATE>(state, input[idx]); |
67 | } else { |
68 | OP::template Execute<INPUT_TYPE, STATE>(state, input[idx]); |
69 | } |
70 | } |
71 | |
72 | static bool IgnoreNull() { |
73 | return true; |
74 | } |
75 | }; |
76 | |
77 | struct NumericMinMaxBase : public MinMaxBase { |
78 | template <class INPUT_TYPE, class STATE> static void Assign(STATE *state, INPUT_TYPE input) { |
79 | state->value = input; |
80 | } |
81 | |
82 | template <class T, class STATE> |
83 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
84 | nullmask[idx] = !state->isset; |
85 | target[idx] = state->value; |
86 | } |
87 | }; |
88 | |
89 | struct MinOperation : public NumericMinMaxBase { |
90 | template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) { |
91 | if (LessThan::Operation<INPUT_TYPE>(input, state->value)) { |
92 | state->value = input; |
93 | } |
94 | } |
95 | |
96 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
97 | if (!source.isset) { |
98 | // source is NULL, nothing to do |
99 | return; |
100 | } |
101 | if (!target->isset) { |
102 | // target is NULL, use source value directly |
103 | *target = source; |
104 | } else if (target->value > source.value) { |
105 | target->value = source.value; |
106 | } |
107 | } |
108 | }; |
109 | |
110 | struct MaxOperation : public NumericMinMaxBase { |
111 | template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) { |
112 | if (GreaterThan::Operation<INPUT_TYPE>(input, state->value)) { |
113 | state->value = input; |
114 | } |
115 | } |
116 | |
117 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
118 | if (!source.isset) { |
119 | // source is NULL, nothing to do |
120 | return; |
121 | } |
122 | if (!target->isset) { |
123 | // target is NULL, use source value directly |
124 | *target = source; |
125 | } else if (target->value < source.value) { |
126 | target->value = source.value; |
127 | } |
128 | } |
129 | }; |
130 | |
131 | struct StringMinMaxBase : public MinMaxBase { |
132 | template <class STATE> static void Destroy(STATE *state) { |
133 | if (state->isset && !state->value.IsInlined()) { |
134 | delete[] state->value.GetData(); |
135 | } |
136 | } |
137 | |
138 | template <class INPUT_TYPE, class STATE> static void Assign(STATE *state, INPUT_TYPE input) { |
139 | if (input.IsInlined()) { |
140 | state->value = input; |
141 | } else { |
142 | // non-inlined string, need to allocate space for it |
143 | auto len = input.GetSize(); |
144 | auto ptr = new char[len + 1]; |
145 | memcpy(ptr, input.GetData(), len + 1); |
146 | |
147 | state->value = string_t(ptr, len); |
148 | } |
149 | } |
150 | |
151 | template <class T, class STATE> |
152 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
153 | if (!state->isset) { |
154 | nullmask[idx] = true; |
155 | } else { |
156 | target[idx] = StringVector::AddString(result, state->value); |
157 | } |
158 | } |
159 | |
160 | template <class STATE, class OP> static void Combine(STATE source, STATE *target) { |
161 | if (!source.isset) { |
162 | // source is NULL, nothing to do |
163 | return; |
164 | } |
165 | if (!target->isset) { |
166 | // target is NULL, use source value directly |
167 | *target = source; |
168 | } else { |
169 | OP::template Execute<string_t, STATE>(target, source.value); |
170 | } |
171 | } |
172 | }; |
173 | |
174 | struct MinOperationString : public StringMinMaxBase { |
175 | template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) { |
176 | if (LessThan::Operation<INPUT_TYPE>(input, state->value)) { |
177 | Assign(state, input); |
178 | } |
179 | } |
180 | }; |
181 | |
182 | struct MaxOperationString : public StringMinMaxBase { |
183 | template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) { |
184 | if (GreaterThan::Operation<INPUT_TYPE>(input, state->value)) { |
185 | Assign(state, input); |
186 | } |
187 | } |
188 | }; |
189 | |
190 | template <class OP, class OP_STRING> static void AddMinMaxOperator(AggregateFunctionSet &set) { |
191 | for (auto type : SQLType::ALL_TYPES) { |
192 | if (type.id == SQLTypeId::VARCHAR || type.id == SQLTypeId::BLOB) { |
193 | set.AddFunction( |
194 | AggregateFunction::UnaryAggregateDestructor<min_max_state_t<string_t>, string_t, string_t, OP_STRING>( |
195 | type.id, type.id)); |
196 | } else { |
197 | set.AddFunction(GetUnaryAggregate<OP>(type)); |
198 | } |
199 | } |
200 | } |
201 | |
202 | void MinFun::RegisterFunction(BuiltinFunctions &set) { |
203 | AggregateFunctionSet min("min" ); |
204 | AddMinMaxOperator<MinOperation, MinOperationString>(min); |
205 | set.AddFunction(min); |
206 | } |
207 | |
208 | void MaxFun::RegisterFunction(BuiltinFunctions &set) { |
209 | AggregateFunctionSet max("max" ); |
210 | AddMinMaxOperator<MaxOperation, MaxOperationString>(max); |
211 | set.AddFunction(max); |
212 | } |
213 | |
214 | } // namespace duckdb |
215 | |