1 | #include "duckdb/execution/operator/aggregate/aggregate_object.hpp" |
2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
3 | #include "duckdb/planner/expression/bound_window_expression.hpp" |
4 | |
5 | namespace duckdb { |
6 | |
7 | AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count, |
8 | idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, |
9 | Expression *filter) |
10 | : function(std::move(function)), |
11 | bind_data_wrapper(bind_data ? make_shared<FunctionDataWrapper>(args: bind_data->Copy()) : nullptr), |
12 | child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), |
13 | filter(filter) { |
14 | } |
15 | |
16 | AggregateObject::AggregateObject(BoundAggregateExpression *aggr) |
17 | : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), |
18 | AlignValue(n: aggr->function.state_size()), aggr->aggr_type, aggr->return_type.InternalType(), |
19 | aggr->filter.get()) { |
20 | } |
21 | |
22 | AggregateObject::AggregateObject(BoundWindowExpression &window) |
23 | : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), |
24 | AlignValue(n: window.aggregate->state_size()), AggregateType::NON_DISTINCT, |
25 | window.return_type.InternalType(), window.filter_expr.get()) { |
26 | } |
27 | |
28 | vector<AggregateObject> AggregateObject::CreateAggregateObjects(const vector<BoundAggregateExpression *> &bindings) { |
29 | vector<AggregateObject> aggregates; |
30 | aggregates.reserve(n: aggregates.size()); |
31 | for (auto &binding : bindings) { |
32 | aggregates.emplace_back(args: binding); |
33 | } |
34 | return aggregates; |
35 | } |
36 | |
37 | AggregateFilterData::AggregateFilterData(ClientContext &context, Expression &filter_expr, |
38 | const vector<LogicalType> &payload_types) |
39 | : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) { |
40 | if (payload_types.empty()) { |
41 | return; |
42 | } |
43 | filtered_payload.Initialize(allocator&: Allocator::Get(context), types: payload_types); |
44 | } |
45 | |
46 | idx_t AggregateFilterData::ApplyFilter(DataChunk &payload) { |
47 | filtered_payload.Reset(); |
48 | |
49 | auto count = filter_executor.SelectExpression(input&: payload, sel&: true_sel); |
50 | filtered_payload.Slice(other&: payload, sel: true_sel, count); |
51 | return count; |
52 | } |
53 | |
54 | AggregateFilterDataSet::AggregateFilterDataSet() { |
55 | } |
56 | |
57 | void AggregateFilterDataSet::Initialize(ClientContext &context, const vector<AggregateObject> &aggregates, |
58 | const vector<LogicalType> &payload_types) { |
59 | bool has_filters = false; |
60 | for (auto &aggregate : aggregates) { |
61 | if (aggregate.filter) { |
62 | has_filters = true; |
63 | break; |
64 | } |
65 | } |
66 | if (!has_filters) { |
67 | // no filters: nothing to do |
68 | return; |
69 | } |
70 | filter_data.resize(new_size: aggregates.size()); |
71 | for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { |
72 | auto &aggr = aggregates[aggr_idx]; |
73 | if (aggr.filter) { |
74 | filter_data[aggr_idx] = make_uniq<AggregateFilterData>(args&: context, args&: *aggr.filter, args: payload_types); |
75 | } |
76 | } |
77 | } |
78 | |
79 | AggregateFilterData &AggregateFilterDataSet::GetFilterData(idx_t aggr_idx) { |
80 | D_ASSERT(aggr_idx < filter_data.size()); |
81 | D_ASSERT(filter_data[aggr_idx]); |
82 | return *filter_data[aggr_idx]; |
83 | } |
84 | } // namespace duckdb |
85 | |