1#include "duckdb/execution/operator/aggregate/aggregate_object.hpp"
2#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
3#include "duckdb/planner/expression/bound_window_expression.hpp"
4
5namespace duckdb {
6
7AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count,
8 idx_t payload_size, AggregateType aggr_type, PhysicalType return_type,
9 Expression *filter)
10 : function(std::move(function)),
11 bind_data_wrapper(bind_data ? make_shared<FunctionDataWrapper>(args: bind_data->Copy()) : nullptr),
12 child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type),
13 filter(filter) {
14}
15
16AggregateObject::AggregateObject(BoundAggregateExpression *aggr)
17 : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(),
18 AlignValue(n: aggr->function.state_size()), aggr->aggr_type, aggr->return_type.InternalType(),
19 aggr->filter.get()) {
20}
21
22AggregateObject::AggregateObject(BoundWindowExpression &window)
23 : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(),
24 AlignValue(n: window.aggregate->state_size()), AggregateType::NON_DISTINCT,
25 window.return_type.InternalType(), window.filter_expr.get()) {
26}
27
28vector<AggregateObject> AggregateObject::CreateAggregateObjects(const vector<BoundAggregateExpression *> &bindings) {
29 vector<AggregateObject> aggregates;
30 aggregates.reserve(n: aggregates.size());
31 for (auto &binding : bindings) {
32 aggregates.emplace_back(args: binding);
33 }
34 return aggregates;
35}
36
37AggregateFilterData::AggregateFilterData(ClientContext &context, Expression &filter_expr,
38 const vector<LogicalType> &payload_types)
39 : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) {
40 if (payload_types.empty()) {
41 return;
42 }
43 filtered_payload.Initialize(allocator&: Allocator::Get(context), types: payload_types);
44}
45
46idx_t AggregateFilterData::ApplyFilter(DataChunk &payload) {
47 filtered_payload.Reset();
48
49 auto count = filter_executor.SelectExpression(input&: payload, sel&: true_sel);
50 filtered_payload.Slice(other&: payload, sel: true_sel, count);
51 return count;
52}
53
54AggregateFilterDataSet::AggregateFilterDataSet() {
55}
56
57void AggregateFilterDataSet::Initialize(ClientContext &context, const vector<AggregateObject> &aggregates,
58 const vector<LogicalType> &payload_types) {
59 bool has_filters = false;
60 for (auto &aggregate : aggregates) {
61 if (aggregate.filter) {
62 has_filters = true;
63 break;
64 }
65 }
66 if (!has_filters) {
67 // no filters: nothing to do
68 return;
69 }
70 filter_data.resize(new_size: aggregates.size());
71 for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
72 auto &aggr = aggregates[aggr_idx];
73 if (aggr.filter) {
74 filter_data[aggr_idx] = make_uniq<AggregateFilterData>(args&: context, args&: *aggr.filter, args: payload_types);
75 }
76 }
77}
78
79AggregateFilterData &AggregateFilterDataSet::GetFilterData(idx_t aggr_idx) {
80 D_ASSERT(aggr_idx < filter_data.size());
81 D_ASSERT(filter_data[aggr_idx]);
82 return *filter_data[aggr_idx];
83}
84} // namespace duckdb
85