1#include "duckdb/function/aggregate/distributive_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4#include "duckdb/common/operator/comparison_operators.hpp"
5#include "duckdb/common/vector_operations/aggregate_executor.hpp"
6#include "duckdb/common/operator/aggregate_operators.hpp"
7#include "duckdb/common/types/null_value.hpp"
8
9using namespace std;
10
11namespace duckdb {
12
//! Aggregate state shared by MIN and MAX: the extreme value found so far plus a flag
//! recording whether any value has been stored yet.
template <typename T>
struct min_max_state_t {
	T value;    // current minimum/maximum; only meaningful once isset is true
	bool isset; // false until the first value has been assigned
};
17
18template <class OP> static AggregateFunction GetUnaryAggregate(SQLType type) {
19 switch (type.id) {
20 case SQLTypeId::BOOLEAN:
21 return AggregateFunction::UnaryAggregate<min_max_state_t<int8_t>, int8_t, int8_t, OP>(type, type);
22 case SQLTypeId::TINYINT:
23 return AggregateFunction::UnaryAggregate<min_max_state_t<int8_t>, int8_t, int8_t, OP>(type, type);
24 case SQLTypeId::SMALLINT:
25 return AggregateFunction::UnaryAggregate<min_max_state_t<int16_t>, int16_t, int16_t, OP>(type, type);
26 case SQLTypeId::INTEGER:
27 return AggregateFunction::UnaryAggregate<min_max_state_t<int32_t>, int32_t, int32_t, OP>(type, type);
28 case SQLTypeId::BIGINT:
29 return AggregateFunction::UnaryAggregate<min_max_state_t<int64_t>, int64_t, int64_t, OP>(type, type);
30 case SQLTypeId::FLOAT:
31 return AggregateFunction::UnaryAggregate<min_max_state_t<float>, float, float, OP>(type, type);
32 case SQLTypeId::DOUBLE:
33 return AggregateFunction::UnaryAggregate<min_max_state_t<double>, double, double, OP>(type, type);
34 case SQLTypeId::DECIMAL:
35 return AggregateFunction::UnaryAggregate<min_max_state_t<double>, double, double, OP>(type, type);
36 case SQLTypeId::DATE:
37 return AggregateFunction::UnaryAggregate<min_max_state_t<date_t>, date_t, date_t, OP>(type, type);
38 case SQLTypeId::TIMESTAMP:
39 return AggregateFunction::UnaryAggregate<min_max_state_t<timestamp_t>, timestamp_t, timestamp_t, OP>(type,
40 type);
41 default:
42 throw NotImplementedException("Unimplemented type for unary aggregate");
43 }
44}
45
46struct MinMaxBase {
47 template <class STATE> static void Initialize(STATE *state) {
48 state->isset = false;
49 }
50
51 template <class INPUT_TYPE, class STATE, class OP>
52 static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) {
53 assert(!nullmask[0]);
54 if (!state->isset) {
55 state->isset = true;
56 OP::template Assign<INPUT_TYPE, STATE>(state, input[0]);
57 } else {
58 OP::template Execute<INPUT_TYPE, STATE>(state, input[0]);
59 }
60 }
61
62 template <class INPUT_TYPE, class STATE, class OP>
63 static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) {
64 if (!state->isset) {
65 state->isset = true;
66 OP::template Assign<INPUT_TYPE, STATE>(state, input[idx]);
67 } else {
68 OP::template Execute<INPUT_TYPE, STATE>(state, input[idx]);
69 }
70 }
71
72 static bool IgnoreNull() {
73 return true;
74 }
75};
76
77struct NumericMinMaxBase : public MinMaxBase {
78 template <class INPUT_TYPE, class STATE> static void Assign(STATE *state, INPUT_TYPE input) {
79 state->value = input;
80 }
81
82 template <class T, class STATE>
83 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
84 nullmask[idx] = !state->isset;
85 target[idx] = state->value;
86 }
87};
88
89struct MinOperation : public NumericMinMaxBase {
90 template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) {
91 if (LessThan::Operation<INPUT_TYPE>(input, state->value)) {
92 state->value = input;
93 }
94 }
95
96 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
97 if (!source.isset) {
98 // source is NULL, nothing to do
99 return;
100 }
101 if (!target->isset) {
102 // target is NULL, use source value directly
103 *target = source;
104 } else if (target->value > source.value) {
105 target->value = source.value;
106 }
107 }
108};
109
110struct MaxOperation : public NumericMinMaxBase {
111 template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) {
112 if (GreaterThan::Operation<INPUT_TYPE>(input, state->value)) {
113 state->value = input;
114 }
115 }
116
117 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
118 if (!source.isset) {
119 // source is NULL, nothing to do
120 return;
121 }
122 if (!target->isset) {
123 // target is NULL, use source value directly
124 *target = source;
125 } else if (target->value < source.value) {
126 target->value = source.value;
127 }
128 }
129};
130
131struct StringMinMaxBase : public MinMaxBase {
132 template <class STATE> static void Destroy(STATE *state) {
133 if (state->isset && !state->value.IsInlined()) {
134 delete[] state->value.GetData();
135 }
136 }
137
138 template <class INPUT_TYPE, class STATE> static void Assign(STATE *state, INPUT_TYPE input) {
139 if (input.IsInlined()) {
140 state->value = input;
141 } else {
142 // non-inlined string, need to allocate space for it
143 auto len = input.GetSize();
144 auto ptr = new char[len + 1];
145 memcpy(ptr, input.GetData(), len + 1);
146
147 state->value = string_t(ptr, len);
148 }
149 }
150
151 template <class T, class STATE>
152 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
153 if (!state->isset) {
154 nullmask[idx] = true;
155 } else {
156 target[idx] = StringVector::AddString(result, state->value);
157 }
158 }
159
160 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
161 if (!source.isset) {
162 // source is NULL, nothing to do
163 return;
164 }
165 if (!target->isset) {
166 // target is NULL, use source value directly
167 *target = source;
168 } else {
169 OP::template Execute<string_t, STATE>(target, source.value);
170 }
171 }
172};
173
174struct MinOperationString : public StringMinMaxBase {
175 template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) {
176 if (LessThan::Operation<INPUT_TYPE>(input, state->value)) {
177 Assign(state, input);
178 }
179 }
180};
181
182struct MaxOperationString : public StringMinMaxBase {
183 template <class INPUT_TYPE, class STATE> static void Execute(STATE *state, INPUT_TYPE input) {
184 if (GreaterThan::Operation<INPUT_TYPE>(input, state->value)) {
185 Assign(state, input);
186 }
187 }
188};
189
190template <class OP, class OP_STRING> static void AddMinMaxOperator(AggregateFunctionSet &set) {
191 for (auto type : SQLType::ALL_TYPES) {
192 if (type.id == SQLTypeId::VARCHAR || type.id == SQLTypeId::BLOB) {
193 set.AddFunction(
194 AggregateFunction::UnaryAggregateDestructor<min_max_state_t<string_t>, string_t, string_t, OP_STRING>(
195 type.id, type.id));
196 } else {
197 set.AddFunction(GetUnaryAggregate<OP>(type));
198 }
199 }
200}
201
202void MinFun::RegisterFunction(BuiltinFunctions &set) {
203 AggregateFunctionSet min("min");
204 AddMinMaxOperator<MinOperation, MinOperationString>(min);
205 set.AddFunction(min);
206}
207
208void MaxFun::RegisterFunction(BuiltinFunctions &set) {
209 AggregateFunctionSet max("max");
210 AddMinMaxOperator<MaxOperation, MaxOperationString>(max);
211 set.AddFunction(max);
212}
213
214} // namespace duckdb
215