1//===----------------------------------------------------------------------===//
2// DuckDB
3//
4// duckdb/function/aggregate_function.hpp
5//
6//
7//===----------------------------------------------------------------------===//
8
9#pragma once
10
11#include "duckdb/function/aggregate_state.hpp"
12#include "duckdb/planner/bound_result_modifier.hpp"
13#include "duckdb/planner/expression.hpp"
14#include "duckdb/common/vector_operations/aggregate_executor.hpp"
15
16namespace duckdb {
17
18//! The type used for sizing hashed aggregate function states
19typedef idx_t (*aggregate_size_t)();
20//! The type used for initializing hashed aggregate function states
21typedef void (*aggregate_initialize_t)(data_ptr_t state);
22//! The type used for updating hashed aggregate functions
23typedef void (*aggregate_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
24 Vector &state, idx_t count);
25//! The type used for combining hashed aggregate states
26typedef void (*aggregate_combine_t)(Vector &state, Vector &combined, AggregateInputData &aggr_input_data, idx_t count);
27//! The type used for finalizing hashed aggregate function payloads
28typedef void (*aggregate_finalize_t)(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
29 idx_t offset);
30//! The type used for propagating statistics in aggregate functions (optional)
31typedef unique_ptr<BaseStatistics> (*aggregate_statistics_t)(ClientContext &context, BoundAggregateExpression &expr,
32 AggregateStatisticsInput &input);
33//! Binds the scalar function and creates the function data
34typedef unique_ptr<FunctionData> (*bind_aggregate_function_t)(ClientContext &context, AggregateFunction &function,
35 vector<unique_ptr<Expression>> &arguments);
36//! The type used for the aggregate destructor method. NOTE: this method is used in destructors and MAY NOT throw.
37typedef void (*aggregate_destructor_t)(Vector &state, AggregateInputData &aggr_input_data, idx_t count);
38
39//! The type used for updating simple (non-grouped) aggregate functions
40typedef void (*aggregate_simple_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
41 data_ptr_t state, idx_t count);
42
43//! The type used for updating complex windowed aggregate functions (optional)
44typedef std::pair<idx_t, idx_t> FrameBounds;
45typedef void (*aggregate_window_t)(Vector inputs[], const ValidityMask &filter_mask,
46 AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state,
47 const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid,
48 idx_t bias);
49
50typedef void (*aggregate_serialize_t)(FieldWriter &writer, const FunctionData *bind_data,
51 const AggregateFunction &function);
52typedef unique_ptr<FunctionData> (*aggregate_deserialize_t)(PlanDeserializationState &context, FieldReader &reader,
53 AggregateFunction &function);
54
55class AggregateFunction : public BaseScalarFunction {
56public:
57 AggregateFunction(const string &name, const vector<LogicalType> &arguments, const LogicalType &return_type,
58 aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update,
59 aggregate_combine_t combine, aggregate_finalize_t finalize,
60 FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING,
61 aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr,
62 aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr,
63 aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr,
64 aggregate_deserialize_t deserialize = nullptr)
65 : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS,
66 LogicalType(LogicalTypeId::INVALID), null_handling),
67 state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize),
68 simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics),
69 serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) {
70 }
71
72 AggregateFunction(const string &name, const vector<LogicalType> &arguments, const LogicalType &return_type,
73 aggregate_size_t state_size, aggregate_initialize_t initialize, aggregate_update_t update,
74 aggregate_combine_t combine, aggregate_finalize_t finalize,
75 aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr,
76 aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr,
77 aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr,
78 aggregate_deserialize_t deserialize = nullptr)
79 : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS,
80 LogicalType(LogicalTypeId::INVALID)),
81 state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize),
82 simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics),
83 serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) {
84 }
85
86 AggregateFunction(const vector<LogicalType> &arguments, const LogicalType &return_type, aggregate_size_t state_size,
87 aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine,
88 aggregate_finalize_t finalize,
89 FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING,
90 aggregate_simple_update_t simple_update = nullptr, bind_aggregate_function_t bind = nullptr,
91 aggregate_destructor_t destructor = nullptr, aggregate_statistics_t statistics = nullptr,
92 aggregate_window_t window = nullptr, aggregate_serialize_t serialize = nullptr,
93 aggregate_deserialize_t deserialize = nullptr)
94 : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize,
95 null_handling, simple_update, bind, destructor, statistics, window, serialize,
96 deserialize) {
97 }
98
99 AggregateFunction(const vector<LogicalType> &arguments, const LogicalType &return_type, aggregate_size_t state_size,
100 aggregate_initialize_t initialize, aggregate_update_t update, aggregate_combine_t combine,
101 aggregate_finalize_t finalize, aggregate_simple_update_t simple_update = nullptr,
102 bind_aggregate_function_t bind = nullptr, aggregate_destructor_t destructor = nullptr,
103 aggregate_statistics_t statistics = nullptr, aggregate_window_t window = nullptr,
104 aggregate_serialize_t serialize = nullptr, aggregate_deserialize_t deserialize = nullptr)
105 : AggregateFunction(string(), arguments, return_type, state_size, initialize, update, combine, finalize,
106 FunctionNullHandling::DEFAULT_NULL_HANDLING, simple_update, bind, destructor, statistics,
107 window, serialize, deserialize) {
108 }
109 //! The hashed aggregate state sizing function
110 aggregate_size_t state_size;
111 //! The hashed aggregate state initialization function
112 aggregate_initialize_t initialize;
113 //! The hashed aggregate update state function
114 aggregate_update_t update;
115 //! The hashed aggregate combine states function
116 aggregate_combine_t combine;
117 //! The hashed aggregate finalization function
118 aggregate_finalize_t finalize;
119 //! The simple aggregate update function (may be null)
120 aggregate_simple_update_t simple_update;
121 //! The windowed aggregate frame update function (may be null)
122 aggregate_window_t window;
123
124 //! The bind function (may be null)
125 bind_aggregate_function_t bind;
126 //! The destructor method (may be null)
127 aggregate_destructor_t destructor;
128
129 //! The statistics propagation function (may be null)
130 aggregate_statistics_t statistics;
131
132 aggregate_serialize_t serialize;
133 aggregate_deserialize_t deserialize;
134 //! Whether or not the aggregate is order dependent
135 AggregateOrderDependent order_dependent;
136
137 bool operator==(const AggregateFunction &rhs) const {
138 return state_size == rhs.state_size && initialize == rhs.initialize && update == rhs.update &&
139 combine == rhs.combine && finalize == rhs.finalize && window == rhs.window;
140 }
141 bool operator!=(const AggregateFunction &rhs) const {
142 return !(*this == rhs);
143 }
144
145public:
146 template <class STATE, class RESULT_TYPE, class OP>
147 static AggregateFunction NullaryAggregate(LogicalType return_type) {
148 return AggregateFunction(
149 {}, return_type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
150 AggregateFunction::NullaryScatterUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
151 AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::NullaryUpdate<STATE, OP>);
152 }
153
154 template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
155 static AggregateFunction
156 UnaryAggregate(const LogicalType &input_type, LogicalType return_type,
157 FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING) {
158 return AggregateFunction(
159 {input_type}, return_type, AggregateFunction::StateSize<STATE>,
160 AggregateFunction::StateInitialize<STATE, OP>, AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>,
161 AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>,
162 null_handling, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>);
163 }
164
165 template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
166 static AggregateFunction UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) {
167 auto aggregate = UnaryAggregate<STATE, INPUT_TYPE, RESULT_TYPE, OP>(input_type, return_type);
168 aggregate.destructor = AggregateFunction::StateDestroy<STATE, OP>;
169 return aggregate;
170 }
171
172 template <class STATE, class A_TYPE, class B_TYPE, class RESULT_TYPE, class OP>
173 static AggregateFunction BinaryAggregate(const LogicalType &a_type, const LogicalType &b_type,
174 LogicalType return_type) {
175 return AggregateFunction({a_type, b_type}, return_type, AggregateFunction::StateSize<STATE>,
176 AggregateFunction::StateInitialize<STATE, OP>,
177 AggregateFunction::BinaryScatterUpdate<STATE, A_TYPE, B_TYPE, OP>,
178 AggregateFunction::StateCombine<STATE, OP>,
179 AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>,
180 AggregateFunction::BinaryUpdate<STATE, A_TYPE, B_TYPE, OP>);
181 }
182
183public:
184 template <class STATE>
185 static idx_t StateSize() {
186 return sizeof(STATE);
187 }
188
189 template <class STATE, class OP>
190 static void StateInitialize(data_ptr_t state) {
191 OP::Initialize(*reinterpret_cast<STATE *>(state));
192 }
193
194 template <class STATE, class OP>
195 static void NullaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
196 Vector &states, idx_t count) {
197 D_ASSERT(input_count == 0);
198 AggregateExecutor::NullaryScatter<STATE, OP>(states, aggr_input_data, count);
199 }
200
201 template <class STATE, class OP>
202 static void NullaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state,
203 idx_t count) {
204 D_ASSERT(input_count == 0);
205 AggregateExecutor::NullaryUpdate<STATE, OP>(state, aggr_input_data, count);
206 }
207
208 template <class STATE, class T, class OP>
209 static void UnaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
210 Vector &states, idx_t count) {
211 D_ASSERT(input_count == 1);
212 AggregateExecutor::UnaryScatter<STATE, T, OP>(inputs[0], states, aggr_input_data, count);
213 }
214
215 template <class STATE, class INPUT_TYPE, class OP>
216 static void UnaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state,
217 idx_t count) {
218 D_ASSERT(input_count == 1);
219 AggregateExecutor::UnaryUpdate<STATE, INPUT_TYPE, OP>(inputs[0], aggr_input_data, state, count);
220 }
221
222 template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
223 static void UnaryWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data,
224 idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev,
225 Vector &result, idx_t rid, idx_t bias) {
226 D_ASSERT(input_count == 1);
227 AggregateExecutor::UnaryWindow<STATE, INPUT_TYPE, RESULT_TYPE, OP>(inputs[0], filter_mask, aggr_input_data,
228 state, frame, prev, result, rid, bias);
229 }
230
231 template <class STATE, class A_TYPE, class B_TYPE, class OP>
232 static void BinaryScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
233 Vector &states, idx_t count) {
234 D_ASSERT(input_count == 2);
235 AggregateExecutor::BinaryScatter<STATE, A_TYPE, B_TYPE, OP>(aggr_input_data, inputs[0], inputs[1], states,
236 count);
237 }
238
239 template <class STATE, class A_TYPE, class B_TYPE, class OP>
240 static void BinaryUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state,
241 idx_t count) {
242 D_ASSERT(input_count == 2);
243 AggregateExecutor::BinaryUpdate<STATE, A_TYPE, B_TYPE, OP>(aggr_input_data, inputs[0], inputs[1], state, count);
244 }
245
246 template <class STATE, class OP>
247 static void StateCombine(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) {
248 AggregateExecutor::Combine<STATE, OP>(source, target, aggr_input_data, count);
249 }
250
251 template <class STATE, class RESULT_TYPE, class OP>
252 static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
253 idx_t offset) {
254 AggregateExecutor::Finalize<STATE, RESULT_TYPE, OP>(states, aggr_input_data, result, count, offset);
255 }
256
257 template <class STATE, class OP>
258 static void StateVoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
259 idx_t offset) {
260 AggregateExecutor::VoidFinalize<STATE, OP>(states, aggr_input_data, result, count, offset);
261 }
262
263 template <class STATE, class OP>
264 static void StateDestroy(Vector &states, AggregateInputData &aggr_input_data, idx_t count) {
265 AggregateExecutor::Destroy<STATE, OP>(states, aggr_input_data, count);
266 }
267};
268
269} // namespace duckdb
270