1#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
2#include "duckdb/function/scalar/generic_functions.hpp"
3#include "duckdb/main/client_context.hpp"
4#include "duckdb/main/database.hpp"
5#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
6#include "duckdb/planner/expression/bound_constant_expression.hpp"
7#include "duckdb/planner/expression/bound_function_expression.hpp"
8#include "duckdb/function/function_binder.hpp"
9
10namespace duckdb {
11
12// aggregate state export
13struct ExportAggregateBindData : public FunctionData {
14 AggregateFunction aggr;
15 idx_t state_size;
16
17 explicit ExportAggregateBindData(AggregateFunction aggr_p, idx_t state_size_p)
18 : aggr(std::move(aggr_p)), state_size(state_size_p) {
19 }
20
21 unique_ptr<FunctionData> Copy() const override {
22 return make_uniq<ExportAggregateBindData>(args: aggr, args: state_size);
23 }
24
25 bool Equals(const FunctionData &other_p) const override {
26 auto &other = other_p.Cast<ExportAggregateBindData>();
27 return aggr == other.aggr && state_size == other.state_size;
28 }
29
30 static ExportAggregateBindData &GetFrom(ExpressionState &state) {
31 auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
32 return func_expr.bind_info->Cast<ExportAggregateBindData>();
33 }
34};
35
36struct CombineState : public FunctionLocalState {
37 idx_t state_size;
38
39 unsafe_unique_array<data_t> state_buffer0, state_buffer1;
40 Vector state_vector0, state_vector1;
41
42 explicit CombineState(idx_t state_size_p)
43 : state_size(state_size_p), state_buffer0(make_unsafe_uniq_array<data_t>(n: state_size_p)),
44 state_buffer1(make_unsafe_uniq_array<data_t>(n: state_size_p)),
45 state_vector0(Value::POINTER(value: CastPointerToValue(src: state_buffer0.get()))),
46 state_vector1(Value::POINTER(value: CastPointerToValue(src: state_buffer1.get()))) {
47 }
48};
49
50static unique_ptr<FunctionLocalState> InitCombineState(ExpressionState &state, const BoundFunctionExpression &expr,
51 FunctionData *bind_data_p) {
52 auto &bind_data = bind_data_p->Cast<ExportAggregateBindData>();
53 return make_uniq<CombineState>(args&: bind_data.state_size);
54}
55
56struct FinalizeState : public FunctionLocalState {
57 idx_t state_size;
58 unsafe_unique_array<data_t> state_buffer;
59 Vector addresses;
60
61 explicit FinalizeState(idx_t state_size_p)
62 : state_size(state_size_p),
63 state_buffer(make_unsafe_uniq_array<data_t>(STANDARD_VECTOR_SIZE * AlignValue(n: state_size_p))),
64 addresses(LogicalType::POINTER) {
65 }
66};
67
68static unique_ptr<FunctionLocalState> InitFinalizeState(ExpressionState &state, const BoundFunctionExpression &expr,
69 FunctionData *bind_data_p) {
70 auto &bind_data = bind_data_p->Cast<ExportAggregateBindData>();
71 return make_uniq<FinalizeState>(args&: bind_data.state_size);
72}
73
74static void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector &result) {
75 auto &bind_data = ExportAggregateBindData::GetFrom(state&: state_p);
76 auto &local_state = ExecuteFunctionState::GetFunctionState(state&: state_p)->Cast<FinalizeState>();
77
78 D_ASSERT(bind_data.state_size == bind_data.aggr.state_size());
79 D_ASSERT(input.data.size() == 1);
80 D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE);
81 auto aligned_state_size = AlignValue(n: bind_data.state_size);
82
83 auto state_vec_ptr = FlatVector::GetData<data_ptr_t>(vector&: local_state.addresses);
84
85 UnifiedVectorFormat state_data;
86 input.data[0].ToUnifiedFormat(count: input.size(), data&: state_data);
87 for (idx_t i = 0; i < input.size(); i++) {
88 auto state_idx = state_data.sel->get_index(idx: i);
89 auto state_entry = UnifiedVectorFormat::GetData<string_t>(format: state_data) + state_idx;
90 auto target_ptr = char_ptr_cast(src: local_state.state_buffer.get()) + aligned_state_size * i;
91
92 if (state_data.validity.RowIsValid(row_idx: state_idx)) {
93 D_ASSERT(state_entry->GetSize() == bind_data.state_size);
94 memcpy(dest: (void *)target_ptr, src: state_entry->GetData(), n: bind_data.state_size);
95 } else {
96 // create a dummy state because finalize does not understand NULLs in its input
97 // we put the NULL back in explicitly below
98 bind_data.aggr.initialize(data_ptr_cast(src: target_ptr));
99 }
100 state_vec_ptr[i] = data_ptr_cast(src: target_ptr);
101 }
102
103 AggregateInputData aggr_input_data(nullptr, Allocator::DefaultAllocator());
104 bind_data.aggr.finalize(local_state.addresses, aggr_input_data, result, input.size(), 0);
105
106 for (idx_t i = 0; i < input.size(); i++) {
107 auto state_idx = state_data.sel->get_index(idx: i);
108 if (!state_data.validity.RowIsValid(row_idx: state_idx)) {
109 FlatVector::SetNull(vector&: result, idx: i, is_null: true);
110 }
111 }
112}
113
114static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &result) {
115 auto &bind_data = ExportAggregateBindData::GetFrom(state&: state_p);
116 auto &local_state = ExecuteFunctionState::GetFunctionState(state&: state_p)->Cast<CombineState>();
117
118 D_ASSERT(bind_data.state_size == bind_data.aggr.state_size());
119
120 D_ASSERT(input.data.size() == 2);
121 D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE);
122 D_ASSERT(input.data[0].GetType() == result.GetType());
123
124 if (input.data[0].GetType().InternalType() != input.data[1].GetType().InternalType()) {
125 throw IOException("Aggregate state combine type mismatch, expect %s, got %s",
126 input.data[0].GetType().ToString(), input.data[1].GetType().ToString());
127 }
128
129 UnifiedVectorFormat state0_data, state1_data;
130 input.data[0].ToUnifiedFormat(count: input.size(), data&: state0_data);
131 input.data[1].ToUnifiedFormat(count: input.size(), data&: state1_data);
132
133 auto result_ptr = FlatVector::GetData<string_t>(vector&: result);
134
135 for (idx_t i = 0; i < input.size(); i++) {
136 auto state0_idx = state0_data.sel->get_index(idx: i);
137 auto state1_idx = state1_data.sel->get_index(idx: i);
138
139 auto &state0 = UnifiedVectorFormat::GetData<string_t>(format: state0_data)[state0_idx];
140 auto &state1 = UnifiedVectorFormat::GetData<string_t>(format: state1_data)[state1_idx];
141
142 // if both are NULL, we return NULL. If either of them is not, the result is that one
143 if (!state0_data.validity.RowIsValid(row_idx: state0_idx) && !state1_data.validity.RowIsValid(row_idx: state1_idx)) {
144 FlatVector::SetNull(vector&: result, idx: i, is_null: true);
145 continue;
146 }
147 if (state0_data.validity.RowIsValid(row_idx: state0_idx) && !state1_data.validity.RowIsValid(row_idx: state1_idx)) {
148 result_ptr[i] =
149 StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: state0.GetData()), len: bind_data.state_size);
150 continue;
151 }
152 if (!state0_data.validity.RowIsValid(row_idx: state0_idx) && state1_data.validity.RowIsValid(row_idx: state1_idx)) {
153 result_ptr[i] =
154 StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: state1.GetData()), len: bind_data.state_size);
155 continue;
156 }
157
158 // we actually have to combine
159 if (state0.GetSize() != bind_data.state_size || state1.GetSize() != bind_data.state_size) {
160 throw IOException("Aggregate state size mismatch, expect %llu, got %llu and %llu", bind_data.state_size,
161 state0.GetSize(), state1.GetSize());
162 }
163
164 memcpy(dest: local_state.state_buffer0.get(), src: state0.GetData(), n: bind_data.state_size);
165 memcpy(dest: local_state.state_buffer1.get(), src: state1.GetData(), n: bind_data.state_size);
166
167 AggregateInputData aggr_input_data(nullptr, Allocator::DefaultAllocator());
168 bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1);
169
170 result_ptr[i] = StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: local_state.state_buffer1.get()),
171 len: bind_data.state_size);
172 }
173}
174
175static unique_ptr<FunctionData> BindAggregateState(ClientContext &context, ScalarFunction &bound_function,
176 vector<unique_ptr<Expression>> &arguments) {
177
178 // grab the aggregate type and bind the aggregate again
179
180 // the aggregate name and types are in the logical type of the aggregate state, make sure its sane
181 auto &arg_return_type = arguments[0]->return_type;
182 for (auto &arg_type : bound_function.arguments) {
183 arg_type = arg_return_type;
184 }
185
186 if (arg_return_type.id() != LogicalTypeId::AGGREGATE_STATE) {
187 throw BinderException("Can only FINALIZE aggregate state, not %s", arg_return_type.ToString());
188 }
189 // combine
190 if (arguments.size() == 2 && arguments[0]->return_type != arguments[1]->return_type &&
191 arguments[1]->return_type.id() != LogicalTypeId::BLOB) {
192 throw BinderException("Cannot COMBINE aggregate states from different functions, %s <> %s",
193 arguments[0]->return_type.ToString(), arguments[1]->return_type.ToString());
194 }
195
196 // following error states are only reachable when someone messes up creating the state_type
197 // which is impossible from SQL
198
199 auto state_type = AggregateStateType::GetStateType(type: arg_return_type);
200
201 // now we can look up the function in the catalog again and bind it
202 auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, type: CatalogType::SCALAR_FUNCTION_ENTRY,
203 DEFAULT_SCHEMA, name: state_type.function_name);
204 if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) {
205 throw InternalException("Could not find aggregate %s", state_type.function_name);
206 }
207 auto &aggr = func.Cast<AggregateFunctionCatalogEntry>();
208
209 string error;
210
211 FunctionBinder function_binder(context);
212 idx_t best_function =
213 function_binder.BindFunction(name: aggr.name, functions&: aggr.functions, arguments: state_type.bound_argument_types, error);
214 if (best_function == DConstants::INVALID_INDEX) {
215 throw InternalException("Could not re-bind exported aggregate %s: %s", state_type.function_name, error);
216 }
217 auto bound_aggr = aggr.functions.GetFunctionByOffset(offset: best_function);
218 if (bound_aggr.bind) {
219 // FIXME: this is really hacky
220 // but the aggregate state export needs a rework around how it handles more complex aggregates anyway
221 vector<unique_ptr<Expression>> args;
222 args.reserve(n: state_type.bound_argument_types.size());
223 for (auto &arg_type : state_type.bound_argument_types) {
224 args.push_back(x: make_uniq<BoundConstantExpression>(args: Value(arg_type)));
225 }
226 auto bind_info = bound_aggr.bind(context, bound_aggr, args);
227 if (bind_info) {
228 throw BinderException("Aggregate function with bind info not supported yet in aggregate state export");
229 }
230 }
231
232 if (bound_aggr.return_type != state_type.return_type || bound_aggr.arguments != state_type.bound_argument_types) {
233 throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name);
234 }
235
236 if (bound_function.name == "finalize") {
237 bound_function.return_type = bound_aggr.return_type;
238 } else {
239 D_ASSERT(bound_function.name == "combine");
240 bound_function.return_type = arg_return_type;
241 }
242
243 return make_uniq<ExportAggregateBindData>(args&: bound_aggr, args: bound_aggr.state_size());
244}
245
246static void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
247 idx_t offset) {
248 D_ASSERT(offset == 0);
249 auto &bind_data = aggr_input_data.bind_data->Cast<ExportAggregateFunctionBindData>();
250 auto state_size = bind_data.aggregate->function.state_size();
251 auto blob_ptr = FlatVector::GetData<string_t>(vector&: result);
252 auto addresses_ptr = FlatVector::GetData<data_ptr_t>(vector&: state);
253 for (idx_t row_idx = 0; row_idx < count; row_idx++) {
254 auto data_ptr = addresses_ptr[row_idx];
255 blob_ptr[row_idx] = StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: data_ptr), len: state_size);
256 }
257}
258
259ExportAggregateFunctionBindData::ExportAggregateFunctionBindData(unique_ptr<Expression> aggregate_p) {
260 D_ASSERT(aggregate_p->type == ExpressionType::BOUND_AGGREGATE);
261 aggregate = unique_ptr_cast<Expression, BoundAggregateExpression>(src: std::move(aggregate_p));
262}
263
264unique_ptr<FunctionData> ExportAggregateFunctionBindData::Copy() const {
265 return make_uniq<ExportAggregateFunctionBindData>(args: aggregate->Copy());
266}
267
268bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const {
269 auto &other = other_p.Cast<ExportAggregateFunctionBindData>();
270 return aggregate->Equals(other: *other.aggregate);
271}
272
273static void ExportStateAggregateSerialize(FieldWriter &writer, const FunctionData *bind_data_p,
274 const AggregateFunction &function) {
275 throw NotImplementedException("FIXME: export state serialize");
276}
277static unique_ptr<FunctionData> ExportStateAggregateDeserialize(PlanDeserializationState &state, FieldReader &reader,
278 AggregateFunction &bound_function) {
279 throw NotImplementedException("FIXME: export state deserialize");
280}
281
282static void ExportStateScalarSerialize(FieldWriter &writer, const FunctionData *bind_data_p,
283 const ScalarFunction &function) {
284 throw NotImplementedException("FIXME: export state serialize");
285}
286static unique_ptr<FunctionData> ExportStateScalarDeserialize(PlanDeserializationState &state, FieldReader &reader,
287 ScalarFunction &bound_function) {
288 throw NotImplementedException("FIXME: export state deserialize");
289}
290
291unique_ptr<BoundAggregateExpression>
292ExportAggregateFunction::Bind(unique_ptr<BoundAggregateExpression> child_aggregate) {
293 auto &bound_function = child_aggregate->function;
294 if (!bound_function.combine) {
295 throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name);
296 }
297 if (bound_function.bind) {
298 throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders");
299 }
300 if (bound_function.destructor) {
301 throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors");
302 }
303 // this should be required
304 D_ASSERT(bound_function.state_size);
305 D_ASSERT(bound_function.finalize);
306
307 D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID);
308#ifdef DEBUG
309 for (auto &arg_type : child_aggregate->function.arguments) {
310 D_ASSERT(arg_type.id() != LogicalTypeId::INVALID);
311 }
312#endif
313 auto export_bind_data = make_uniq<ExportAggregateFunctionBindData>(args: child_aggregate->Copy());
314 aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type,
315 child_aggregate->function.arguments);
316 auto return_type = LogicalType::AGGREGATE_STATE(state_type: std::move(state_type));
317
318 auto export_function =
319 AggregateFunction("aggregate_state_export_" + bound_function.name, bound_function.arguments, return_type,
320 bound_function.state_size, bound_function.initialize, bound_function.update,
321 bound_function.combine, ExportAggregateFinalize, bound_function.simple_update,
322 /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr,
323 /* can't propagate statistics */ nullptr, nullptr);
324 export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
325 export_function.serialize = ExportStateAggregateSerialize;
326 export_function.deserialize = ExportStateAggregateDeserialize;
327
328 return make_uniq<BoundAggregateExpression>(args&: export_function, args: std::move(child_aggregate->children),
329 args: std::move(child_aggregate->filter), args: std::move(export_bind_data),
330 args&: child_aggregate->aggr_type);
331}
332
333ScalarFunction ExportAggregateFunction::GetFinalize() {
334 auto result = ScalarFunction("finalize", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID,
335 AggregateStateFinalize, BindAggregateState, nullptr, nullptr, InitFinalizeState);
336 result.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
337 result.serialize = ExportStateScalarSerialize;
338 result.deserialize = ExportStateScalarDeserialize;
339 return result;
340}
341
342ScalarFunction ExportAggregateFunction::GetCombine() {
343 auto result =
344 ScalarFunction("combine", {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY}, LogicalTypeId::AGGREGATE_STATE,
345 AggregateStateCombine, BindAggregateState, nullptr, nullptr, InitCombineState);
346 result.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
347 result.serialize = ExportStateScalarSerialize;
348 result.deserialize = ExportStateScalarDeserialize;
349 return result;
350}
351
352void ExportAggregateFunction::RegisterFunction(BuiltinFunctions &set) {
353 set.AddFunction(function: ExportAggregateFunction::GetCombine());
354 set.AddFunction(function: ExportAggregateFunction::GetFinalize());
355}
356
357} // namespace duckdb
358