1 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
2 | #include "duckdb/function/scalar/generic_functions.hpp" |
3 | #include "duckdb/main/client_context.hpp" |
4 | #include "duckdb/main/database.hpp" |
5 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
8 | #include "duckdb/function/function_binder.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | // aggregate state export |
13 | struct ExportAggregateBindData : public FunctionData { |
14 | AggregateFunction aggr; |
15 | idx_t state_size; |
16 | |
17 | explicit ExportAggregateBindData(AggregateFunction aggr_p, idx_t state_size_p) |
18 | : aggr(std::move(aggr_p)), state_size(state_size_p) { |
19 | } |
20 | |
21 | unique_ptr<FunctionData> Copy() const override { |
22 | return make_uniq<ExportAggregateBindData>(args: aggr, args: state_size); |
23 | } |
24 | |
25 | bool Equals(const FunctionData &other_p) const override { |
26 | auto &other = other_p.Cast<ExportAggregateBindData>(); |
27 | return aggr == other.aggr && state_size == other.state_size; |
28 | } |
29 | |
30 | static ExportAggregateBindData &GetFrom(ExpressionState &state) { |
31 | auto &func_expr = state.expr.Cast<BoundFunctionExpression>(); |
32 | return func_expr.bind_info->Cast<ExportAggregateBindData>(); |
33 | } |
34 | }; |
35 | |
36 | struct CombineState : public FunctionLocalState { |
37 | idx_t state_size; |
38 | |
39 | unsafe_unique_array<data_t> state_buffer0, state_buffer1; |
40 | Vector state_vector0, state_vector1; |
41 | |
42 | explicit CombineState(idx_t state_size_p) |
43 | : state_size(state_size_p), state_buffer0(make_unsafe_uniq_array<data_t>(n: state_size_p)), |
44 | state_buffer1(make_unsafe_uniq_array<data_t>(n: state_size_p)), |
45 | state_vector0(Value::POINTER(value: CastPointerToValue(src: state_buffer0.get()))), |
46 | state_vector1(Value::POINTER(value: CastPointerToValue(src: state_buffer1.get()))) { |
47 | } |
48 | }; |
49 | |
50 | static unique_ptr<FunctionLocalState> InitCombineState(ExpressionState &state, const BoundFunctionExpression &expr, |
51 | FunctionData *bind_data_p) { |
52 | auto &bind_data = bind_data_p->Cast<ExportAggregateBindData>(); |
53 | return make_uniq<CombineState>(args&: bind_data.state_size); |
54 | } |
55 | |
56 | struct FinalizeState : public FunctionLocalState { |
57 | idx_t state_size; |
58 | unsafe_unique_array<data_t> state_buffer; |
59 | Vector addresses; |
60 | |
61 | explicit FinalizeState(idx_t state_size_p) |
62 | : state_size(state_size_p), |
63 | state_buffer(make_unsafe_uniq_array<data_t>(STANDARD_VECTOR_SIZE * AlignValue(n: state_size_p))), |
64 | addresses(LogicalType::POINTER) { |
65 | } |
66 | }; |
67 | |
68 | static unique_ptr<FunctionLocalState> InitFinalizeState(ExpressionState &state, const BoundFunctionExpression &expr, |
69 | FunctionData *bind_data_p) { |
70 | auto &bind_data = bind_data_p->Cast<ExportAggregateBindData>(); |
71 | return make_uniq<FinalizeState>(args&: bind_data.state_size); |
72 | } |
73 | |
74 | static void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector &result) { |
75 | auto &bind_data = ExportAggregateBindData::GetFrom(state&: state_p); |
76 | auto &local_state = ExecuteFunctionState::GetFunctionState(state&: state_p)->Cast<FinalizeState>(); |
77 | |
78 | D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); |
79 | D_ASSERT(input.data.size() == 1); |
80 | D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); |
81 | auto aligned_state_size = AlignValue(n: bind_data.state_size); |
82 | |
83 | auto state_vec_ptr = FlatVector::GetData<data_ptr_t>(vector&: local_state.addresses); |
84 | |
85 | UnifiedVectorFormat state_data; |
86 | input.data[0].ToUnifiedFormat(count: input.size(), data&: state_data); |
87 | for (idx_t i = 0; i < input.size(); i++) { |
88 | auto state_idx = state_data.sel->get_index(idx: i); |
89 | auto state_entry = UnifiedVectorFormat::GetData<string_t>(format: state_data) + state_idx; |
90 | auto target_ptr = char_ptr_cast(src: local_state.state_buffer.get()) + aligned_state_size * i; |
91 | |
92 | if (state_data.validity.RowIsValid(row_idx: state_idx)) { |
93 | D_ASSERT(state_entry->GetSize() == bind_data.state_size); |
94 | memcpy(dest: (void *)target_ptr, src: state_entry->GetData(), n: bind_data.state_size); |
95 | } else { |
96 | // create a dummy state because finalize does not understand NULLs in its input |
97 | // we put the NULL back in explicitly below |
98 | bind_data.aggr.initialize(data_ptr_cast(src: target_ptr)); |
99 | } |
100 | state_vec_ptr[i] = data_ptr_cast(src: target_ptr); |
101 | } |
102 | |
103 | AggregateInputData aggr_input_data(nullptr, Allocator::DefaultAllocator()); |
104 | bind_data.aggr.finalize(local_state.addresses, aggr_input_data, result, input.size(), 0); |
105 | |
106 | for (idx_t i = 0; i < input.size(); i++) { |
107 | auto state_idx = state_data.sel->get_index(idx: i); |
108 | if (!state_data.validity.RowIsValid(row_idx: state_idx)) { |
109 | FlatVector::SetNull(vector&: result, idx: i, is_null: true); |
110 | } |
111 | } |
112 | } |
113 | |
114 | static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &result) { |
115 | auto &bind_data = ExportAggregateBindData::GetFrom(state&: state_p); |
116 | auto &local_state = ExecuteFunctionState::GetFunctionState(state&: state_p)->Cast<CombineState>(); |
117 | |
118 | D_ASSERT(bind_data.state_size == bind_data.aggr.state_size()); |
119 | |
120 | D_ASSERT(input.data.size() == 2); |
121 | D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); |
122 | D_ASSERT(input.data[0].GetType() == result.GetType()); |
123 | |
124 | if (input.data[0].GetType().InternalType() != input.data[1].GetType().InternalType()) { |
125 | throw IOException("Aggregate state combine type mismatch, expect %s, got %s" , |
126 | input.data[0].GetType().ToString(), input.data[1].GetType().ToString()); |
127 | } |
128 | |
129 | UnifiedVectorFormat state0_data, state1_data; |
130 | input.data[0].ToUnifiedFormat(count: input.size(), data&: state0_data); |
131 | input.data[1].ToUnifiedFormat(count: input.size(), data&: state1_data); |
132 | |
133 | auto result_ptr = FlatVector::GetData<string_t>(vector&: result); |
134 | |
135 | for (idx_t i = 0; i < input.size(); i++) { |
136 | auto state0_idx = state0_data.sel->get_index(idx: i); |
137 | auto state1_idx = state1_data.sel->get_index(idx: i); |
138 | |
139 | auto &state0 = UnifiedVectorFormat::GetData<string_t>(format: state0_data)[state0_idx]; |
140 | auto &state1 = UnifiedVectorFormat::GetData<string_t>(format: state1_data)[state1_idx]; |
141 | |
142 | // if both are NULL, we return NULL. If either of them is not, the result is that one |
143 | if (!state0_data.validity.RowIsValid(row_idx: state0_idx) && !state1_data.validity.RowIsValid(row_idx: state1_idx)) { |
144 | FlatVector::SetNull(vector&: result, idx: i, is_null: true); |
145 | continue; |
146 | } |
147 | if (state0_data.validity.RowIsValid(row_idx: state0_idx) && !state1_data.validity.RowIsValid(row_idx: state1_idx)) { |
148 | result_ptr[i] = |
149 | StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: state0.GetData()), len: bind_data.state_size); |
150 | continue; |
151 | } |
152 | if (!state0_data.validity.RowIsValid(row_idx: state0_idx) && state1_data.validity.RowIsValid(row_idx: state1_idx)) { |
153 | result_ptr[i] = |
154 | StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: state1.GetData()), len: bind_data.state_size); |
155 | continue; |
156 | } |
157 | |
158 | // we actually have to combine |
159 | if (state0.GetSize() != bind_data.state_size || state1.GetSize() != bind_data.state_size) { |
160 | throw IOException("Aggregate state size mismatch, expect %llu, got %llu and %llu" , bind_data.state_size, |
161 | state0.GetSize(), state1.GetSize()); |
162 | } |
163 | |
164 | memcpy(dest: local_state.state_buffer0.get(), src: state0.GetData(), n: bind_data.state_size); |
165 | memcpy(dest: local_state.state_buffer1.get(), src: state1.GetData(), n: bind_data.state_size); |
166 | |
167 | AggregateInputData aggr_input_data(nullptr, Allocator::DefaultAllocator()); |
168 | bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); |
169 | |
170 | result_ptr[i] = StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: local_state.state_buffer1.get()), |
171 | len: bind_data.state_size); |
172 | } |
173 | } |
174 | |
175 | static unique_ptr<FunctionData> BindAggregateState(ClientContext &context, ScalarFunction &bound_function, |
176 | vector<unique_ptr<Expression>> &arguments) { |
177 | |
178 | // grab the aggregate type and bind the aggregate again |
179 | |
180 | // the aggregate name and types are in the logical type of the aggregate state, make sure its sane |
181 | auto &arg_return_type = arguments[0]->return_type; |
182 | for (auto &arg_type : bound_function.arguments) { |
183 | arg_type = arg_return_type; |
184 | } |
185 | |
186 | if (arg_return_type.id() != LogicalTypeId::AGGREGATE_STATE) { |
187 | throw BinderException("Can only FINALIZE aggregate state, not %s" , arg_return_type.ToString()); |
188 | } |
189 | // combine |
190 | if (arguments.size() == 2 && arguments[0]->return_type != arguments[1]->return_type && |
191 | arguments[1]->return_type.id() != LogicalTypeId::BLOB) { |
192 | throw BinderException("Cannot COMBINE aggregate states from different functions, %s <> %s" , |
193 | arguments[0]->return_type.ToString(), arguments[1]->return_type.ToString()); |
194 | } |
195 | |
196 | // following error states are only reachable when someone messes up creating the state_type |
197 | // which is impossible from SQL |
198 | |
199 | auto state_type = AggregateStateType::GetStateType(type: arg_return_type); |
200 | |
201 | // now we can look up the function in the catalog again and bind it |
202 | auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, type: CatalogType::SCALAR_FUNCTION_ENTRY, |
203 | DEFAULT_SCHEMA, name: state_type.function_name); |
204 | if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { |
205 | throw InternalException("Could not find aggregate %s" , state_type.function_name); |
206 | } |
207 | auto &aggr = func.Cast<AggregateFunctionCatalogEntry>(); |
208 | |
209 | string error; |
210 | |
211 | FunctionBinder function_binder(context); |
212 | idx_t best_function = |
213 | function_binder.BindFunction(name: aggr.name, functions&: aggr.functions, arguments: state_type.bound_argument_types, error); |
214 | if (best_function == DConstants::INVALID_INDEX) { |
215 | throw InternalException("Could not re-bind exported aggregate %s: %s" , state_type.function_name, error); |
216 | } |
217 | auto bound_aggr = aggr.functions.GetFunctionByOffset(offset: best_function); |
218 | if (bound_aggr.bind) { |
219 | // FIXME: this is really hacky |
220 | // but the aggregate state export needs a rework around how it handles more complex aggregates anyway |
221 | vector<unique_ptr<Expression>> args; |
222 | args.reserve(n: state_type.bound_argument_types.size()); |
223 | for (auto &arg_type : state_type.bound_argument_types) { |
224 | args.push_back(x: make_uniq<BoundConstantExpression>(args: Value(arg_type))); |
225 | } |
226 | auto bind_info = bound_aggr.bind(context, bound_aggr, args); |
227 | if (bind_info) { |
228 | throw BinderException("Aggregate function with bind info not supported yet in aggregate state export" ); |
229 | } |
230 | } |
231 | |
232 | if (bound_aggr.return_type != state_type.return_type || bound_aggr.arguments != state_type.bound_argument_types) { |
233 | throw InternalException("Type mismatch for exported aggregate %s" , state_type.function_name); |
234 | } |
235 | |
236 | if (bound_function.name == "finalize" ) { |
237 | bound_function.return_type = bound_aggr.return_type; |
238 | } else { |
239 | D_ASSERT(bound_function.name == "combine" ); |
240 | bound_function.return_type = arg_return_type; |
241 | } |
242 | |
243 | return make_uniq<ExportAggregateBindData>(args&: bound_aggr, args: bound_aggr.state_size()); |
244 | } |
245 | |
246 | static void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, |
247 | idx_t offset) { |
248 | D_ASSERT(offset == 0); |
249 | auto &bind_data = aggr_input_data.bind_data->Cast<ExportAggregateFunctionBindData>(); |
250 | auto state_size = bind_data.aggregate->function.state_size(); |
251 | auto blob_ptr = FlatVector::GetData<string_t>(vector&: result); |
252 | auto addresses_ptr = FlatVector::GetData<data_ptr_t>(vector&: state); |
253 | for (idx_t row_idx = 0; row_idx < count; row_idx++) { |
254 | auto data_ptr = addresses_ptr[row_idx]; |
255 | blob_ptr[row_idx] = StringVector::AddStringOrBlob(vector&: result, data: const_char_ptr_cast(src: data_ptr), len: state_size); |
256 | } |
257 | } |
258 | |
259 | ExportAggregateFunctionBindData::ExportAggregateFunctionBindData(unique_ptr<Expression> aggregate_p) { |
260 | D_ASSERT(aggregate_p->type == ExpressionType::BOUND_AGGREGATE); |
261 | aggregate = unique_ptr_cast<Expression, BoundAggregateExpression>(src: std::move(aggregate_p)); |
262 | } |
263 | |
264 | unique_ptr<FunctionData> ExportAggregateFunctionBindData::Copy() const { |
265 | return make_uniq<ExportAggregateFunctionBindData>(args: aggregate->Copy()); |
266 | } |
267 | |
268 | bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const { |
269 | auto &other = other_p.Cast<ExportAggregateFunctionBindData>(); |
270 | return aggregate->Equals(other: *other.aggregate); |
271 | } |
272 | |
273 | static void ExportStateAggregateSerialize(FieldWriter &writer, const FunctionData *bind_data_p, |
274 | const AggregateFunction &function) { |
275 | throw NotImplementedException("FIXME: export state serialize" ); |
276 | } |
277 | static unique_ptr<FunctionData> ExportStateAggregateDeserialize(PlanDeserializationState &state, FieldReader &reader, |
278 | AggregateFunction &bound_function) { |
279 | throw NotImplementedException("FIXME: export state deserialize" ); |
280 | } |
281 | |
282 | static void ExportStateScalarSerialize(FieldWriter &writer, const FunctionData *bind_data_p, |
283 | const ScalarFunction &function) { |
284 | throw NotImplementedException("FIXME: export state serialize" ); |
285 | } |
286 | static unique_ptr<FunctionData> ExportStateScalarDeserialize(PlanDeserializationState &state, FieldReader &reader, |
287 | ScalarFunction &bound_function) { |
288 | throw NotImplementedException("FIXME: export state deserialize" ); |
289 | } |
290 | |
291 | unique_ptr<BoundAggregateExpression> |
292 | ExportAggregateFunction::Bind(unique_ptr<BoundAggregateExpression> child_aggregate) { |
293 | auto &bound_function = child_aggregate->function; |
294 | if (!bound_function.combine) { |
295 | throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s" , bound_function.name); |
296 | } |
297 | if (bound_function.bind) { |
298 | throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders" ); |
299 | } |
300 | if (bound_function.destructor) { |
301 | throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors" ); |
302 | } |
303 | // this should be required |
304 | D_ASSERT(bound_function.state_size); |
305 | D_ASSERT(bound_function.finalize); |
306 | |
307 | D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID); |
308 | #ifdef DEBUG |
309 | for (auto &arg_type : child_aggregate->function.arguments) { |
310 | D_ASSERT(arg_type.id() != LogicalTypeId::INVALID); |
311 | } |
312 | #endif |
313 | auto export_bind_data = make_uniq<ExportAggregateFunctionBindData>(args: child_aggregate->Copy()); |
314 | aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type, |
315 | child_aggregate->function.arguments); |
316 | auto return_type = LogicalType::AGGREGATE_STATE(state_type: std::move(state_type)); |
317 | |
318 | auto export_function = |
319 | AggregateFunction("aggregate_state_export_" + bound_function.name, bound_function.arguments, return_type, |
320 | bound_function.state_size, bound_function.initialize, bound_function.update, |
321 | bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, |
322 | /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, |
323 | /* can't propagate statistics */ nullptr, nullptr); |
324 | export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
325 | export_function.serialize = ExportStateAggregateSerialize; |
326 | export_function.deserialize = ExportStateAggregateDeserialize; |
327 | |
328 | return make_uniq<BoundAggregateExpression>(args&: export_function, args: std::move(child_aggregate->children), |
329 | args: std::move(child_aggregate->filter), args: std::move(export_bind_data), |
330 | args&: child_aggregate->aggr_type); |
331 | } |
332 | |
333 | ScalarFunction ExportAggregateFunction::GetFinalize() { |
334 | auto result = ScalarFunction("finalize" , {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID, |
335 | AggregateStateFinalize, BindAggregateState, nullptr, nullptr, InitFinalizeState); |
336 | result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
337 | result.serialize = ExportStateScalarSerialize; |
338 | result.deserialize = ExportStateScalarDeserialize; |
339 | return result; |
340 | } |
341 | |
342 | ScalarFunction ExportAggregateFunction::GetCombine() { |
343 | auto result = |
344 | ScalarFunction("combine" , {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY}, LogicalTypeId::AGGREGATE_STATE, |
345 | AggregateStateCombine, BindAggregateState, nullptr, nullptr, InitCombineState); |
346 | result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
347 | result.serialize = ExportStateScalarSerialize; |
348 | result.deserialize = ExportStateScalarDeserialize; |
349 | return result; |
350 | } |
351 | |
352 | void ExportAggregateFunction::RegisterFunction(BuiltinFunctions &set) { |
353 | set.AddFunction(function: ExportAggregateFunction::GetCombine()); |
354 | set.AddFunction(function: ExportAggregateFunction::GetFinalize()); |
355 | } |
356 | |
357 | } // namespace duckdb |
358 | |